From 6b0a7afe235399790c066dd725c437403a47a73e Mon Sep 17 00:00:00 2001 From: Jordan Krage Date: Tue, 16 Apr 2024 15:27:06 -0500 Subject: [PATCH 01/19] core/services/functions: switch to sqlutil.DataStore (#12811) --- .changeset/real-numbers-taste.md | 5 + core/services/chainlink/application.go | 1 + core/services/functions/listener.go | 15 +- core/services/functions/listener_test.go | 24 +-- core/services/functions/mocks/orm.go | 193 ++++++------------ core/services/functions/orm.go | 98 ++++----- core/services/functions/orm_test.go | 100 +++++---- .../handlers/functions/allowlist/allowlist.go | 14 +- .../functions/allowlist/allowlist_test.go | 28 +-- .../handlers/functions/allowlist/mocks/orm.go | 80 +++----- .../handlers/functions/allowlist/orm.go | 37 ++-- .../handlers/functions/allowlist/orm_test.go | 81 ++++---- .../handlers/functions/handler.functions.go | 4 +- .../functions/subscriptions/mocks/orm.go | 45 ++-- .../handlers/functions/subscriptions/orm.go | 26 +-- .../functions/subscriptions/orm_test.go | 55 ++--- .../functions/subscriptions/subscriptions.go | 8 +- .../subscriptions/subscriptions_test.go | 14 +- core/services/job/spawner_test.go | 2 +- core/services/ocr2/delegate.go | 9 +- .../services/ocr2/plugins/functions/plugin.go | 14 +- .../ocr2/plugins/functions/reporting.go | 11 +- .../ocr2/plugins/functions/reporting_test.go | 44 ++-- .../ocr2/plugins/s4/integration_test.go | 53 +++-- core/services/ocr2/plugins/s4/plugin.go | 13 +- core/services/ocr2/plugins/s4/plugin_test.go | 20 +- core/services/s4/cached_orm_wrapper.go | 22 +- core/services/s4/cached_orm_wrapper_test.go | 69 ++++--- core/services/s4/in_memory_orm.go | 12 +- core/services/s4/in_memory_orm_test.go | 34 +-- core/services/s4/mocks/orm.go | 122 ++++------- core/services/s4/orm.go | 12 +- core/services/s4/postgres_orm.go | 42 ++-- core/services/s4/postgres_orm_test.go | 71 ++++--- core/services/s4/storage.go | 11 +- core/services/s4/storage_test.go | 6 +- 36 files changed, 652 insertions(+), 743 deletions(-) create mode 100644 .changeset/real-numbers-taste.md diff --git a/.changeset/real-numbers-taste.md b/.changeset/real-numbers-taste.md new file mode 100644 index 00000000000..d9f545444c2 --- /dev/null +++ b/.changeset/real-numbers-taste.md @@ -0,0 +1,5 @@ +--- +"chainlink": patch +--- + +core/services/functions: switch to sqlutil.DataStore #internal diff --git a/core/services/chainlink/application.go b/core/services/chainlink/application.go index edc613e25dd..6c373846205 100644 --- a/core/services/chainlink/application.go +++ b/core/services/chainlink/application.go @@ -437,6 +437,7 @@ func NewApplication(opts ApplicationOpts) (Application, error) { delegates[job.OffchainReporting2] = ocr2.NewDelegate( sqlxDB, + opts.DB, jobORM, bridgeORM, mercuryORM, diff --git a/core/services/functions/listener.go b/core/services/functions/listener.go index ff4e268573a..d2033ff74de 100644 --- a/core/services/functions/listener.go +++ b/core/services/functions/listener.go @@ -23,7 +23,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/job" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/functions/config" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/threshold" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" evmrelayTypes "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm/types" "github.com/smartcontractkit/chainlink/v2/core/services/s4" "github.com/smartcontractkit/chainlink/v2/core/services/synchronization/telem" @@ -270,7 +269,7 @@ func (l *functionsListener) setError(ctx context.Context, requestId RequestID, e promRequestComputationError.WithLabelValues(l.contractAddressHex).Inc() } readyForProcessing := errType != INTERNAL_ERROR - if err := l.pluginORM.SetError(requestId, errType, errBytes, time.Now(), readyForProcessing, pg.WithParentCtx(ctx)); err != nil { + if err := l.pluginORM.SetError(ctx, requestId, errType, errBytes, time.Now(), readyForProcessing); err != nil { l.logger.Errorw("call to SetError failed", "requestID", formatRequestId(requestId), "err", err) } } @@ -321,7 +320,7 @@ func (l *functionsListener) HandleOffchainRequest(ctx context.Context, request * CoordinatorContractAddress: &senderAddr, OnchainMetadata: []byte(OffchainRequestMarker), } - if err := l.pluginORM.CreateRequest(newReq, pg.WithParentCtx(ctx)); err != nil { + if err := l.pluginORM.CreateRequest(ctx, newReq); err != nil { if errors.Is(err, ErrDuplicateRequestID) { l.logger.Warnw("HandleOffchainRequest: received duplicate request ID", "requestID", formatRequestId(requestId), "err", err) } else { @@ -348,7 +347,7 @@ func (l *functionsListener) handleOracleRequestV1(request *evmrelayTypes.OracleR CoordinatorContractAddress: &request.CoordinatorContract, OnchainMetadata: request.OnchainMetadata, } - if err := l.pluginORM.CreateRequest(newReq, pg.WithParentCtx(ctx)); err != nil { + if err := l.pluginORM.CreateRequest(ctx, newReq); err != nil { if errors.Is(err, ErrDuplicateRequestID) { l.logger.Warnw("handleOracleRequestV1: received a log with duplicate request ID", "requestID", formatRequestId(request.RequestId), "err", err) } else { @@ -450,7 +449,7 @@ func (l *functionsListener) handleRequest(ctx context.Context, requestID Request promRequestComputationSuccess.WithLabelValues(l.contractAddressHex).Inc() promComputationResultSize.WithLabelValues(l.contractAddressHex).Set(float64(len(computationResult))) l.logger.Debugw("saving computation result", "requestID", requestIDStr) - if err2 := l.pluginORM.SetResult(requestID, computationResult, time.Now(), pg.WithParentCtx(ctx)); err2 != nil { + if err2 := l.pluginORM.SetResult(ctx, requestID, computationResult, time.Now()); err2 != nil { l.logger.Errorw("call to SetResult failed", "requestID", requestIDStr, "err", err2) return err2 } @@ -464,7 +463,7 @@ func (l *functionsListener) handleOracleResponseV1(response *evmrelayTypes.Oracl ctx, cancel := l.getNewHandlerContext() defer cancel() - if err := l.pluginORM.SetConfirmed(response.RequestId, pg.WithParentCtx(ctx)); err != nil { + if err := l.pluginORM.SetConfirmed(ctx, response.RequestId); err != nil { l.logger.Errorw("setting CONFIRMED state failed", "requestID", formatRequestId(response.RequestId), "err", err) } promRequestConfirmed.WithLabelValues(l.contractAddressHex).Inc() @@ -486,7 +485,7 @@ func (l *functionsListener) timeoutRequests() { case <-ticker.C: cutoff := time.Now().Add(-(time.Duration(timeoutSec) * time.Second)) ctx, cancel := l.getNewHandlerContext() - ids, err := l.pluginORM.TimeoutExpiredResults(cutoff, batchSize, pg.WithParentCtx(ctx)) + ids, err := l.pluginORM.TimeoutExpiredResults(ctx, cutoff, batchSize) cancel() if err != nil { l.logger.Errorw("error when calling FindExpiredResults", "err", err) @@ -531,7 +530,7 @@ func (l *functionsListener) pruneRequests() { case <-ticker.C: ctx, cancel := l.getNewHandlerContext() startTime := time.Now() - nTotal, nPruned, err := l.pluginORM.PruneOldestRequests(maxStoredRequests, batchSize, pg.WithParentCtx(ctx)) + nTotal, nPruned, err := l.pluginORM.PruneOldestRequests(ctx, maxStoredRequests, batchSize) cancel() elapsedMillis := time.Since(startTime).Milliseconds() if err != nil { diff --git a/core/services/functions/listener_test.go b/core/services/functions/listener_test.go index 090ced7c91d..d6cd9aa23d6 100644 --- a/core/services/functions/listener_test.go +++ b/core/services/functions/listener_test.go @@ -1,6 +1,7 @@ package functions_test import ( + "context" "encoding/json" "errors" "fmt" @@ -35,7 +36,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/job" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/functions/config" threshold_mocks "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/threshold/mocks" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" evmrelay "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm" "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm/types" @@ -172,7 +172,7 @@ func TestFunctionsListener_HandleOracleRequestV1_Success(t *testing.T) { uni.pluginORM.On("CreateRequest", mock.Anything, mock.Anything).Return(nil) uni.bridgeAccessor.On("NewExternalAdapterClient", mock.Anything).Return(uni.eaClient, nil) uni.eaClient.On("RunComputation", mock.Anything, RequestIDStr, mock.Anything, SubscriptionOwner.Hex(), SubscriptionID, mock.Anything, mock.Anything, mock.Anything).Return(ResultBytes, nil, nil, nil) - uni.pluginORM.On("SetResult", RequestID, ResultBytes, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + uni.pluginORM.On("SetResult", mock.Anything, RequestID, ResultBytes, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { close(doneCh) }).Return(nil) @@ -189,7 +189,7 @@ func TestFunctionsListener_HandleOffchainRequest_Success(t *testing.T) { uni.pluginORM.On("CreateRequest", mock.Anything, mock.Anything).Return(nil) uni.bridgeAccessor.On("NewExternalAdapterClient", mock.Anything).Return(uni.eaClient, nil) uni.eaClient.On("RunComputation", mock.Anything, RequestIDStr, mock.Anything, SubscriptionOwner.Hex(), SubscriptionID, mock.Anything, mock.Anything, mock.Anything).Return(ResultBytes, nil, nil, nil) - uni.pluginORM.On("SetResult", RequestID, ResultBytes, mock.Anything, mock.Anything).Return(nil) + uni.pluginORM.On("SetResult", mock.Anything, RequestID, ResultBytes, mock.Anything, mock.Anything).Return(nil) request := &functions_service.OffchainRequest{ RequestId: RequestID[:], @@ -233,7 +233,7 @@ func TestFunctionsListener_HandleOffchainRequest_InternalError(t *testing.T) { uni.pluginORM.On("CreateRequest", mock.Anything, mock.Anything).Return(nil) uni.bridgeAccessor.On("NewExternalAdapterClient", mock.Anything).Return(uni.eaClient, nil) uni.eaClient.On("RunComputation", mock.Anything, RequestIDStr, mock.Anything, SubscriptionOwner.Hex(), SubscriptionID, mock.Anything, mock.Anything, mock.Anything).Return(nil, nil, nil, errors.New("error")) - uni.pluginORM.On("SetError", RequestID, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) + uni.pluginORM.On("SetError", mock.Anything, RequestID, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) request := &functions_service.OffchainRequest{ RequestId: RequestID[:], @@ -266,7 +266,7 @@ func TestFunctionsListener_HandleOracleRequestV1_ComputationError(t *testing.T) uni.pluginORM.On("CreateRequest", mock.Anything, mock.Anything).Return(nil) uni.bridgeAccessor.On("NewExternalAdapterClient", mock.Anything).Return(uni.eaClient, nil) uni.eaClient.On("RunComputation", mock.Anything, RequestIDStr, mock.Anything, SubscriptionOwner.Hex(), SubscriptionID, mock.Anything, mock.Anything, mock.Anything).Return(nil, ErrorBytes, nil, nil) - uni.pluginORM.On("SetError", RequestID, mock.Anything, ErrorBytes, mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + uni.pluginORM.On("SetError", mock.Anything, RequestID, mock.Anything, ErrorBytes, mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { close(doneCh) }).Return(nil) @@ -307,7 +307,7 @@ func TestFunctionsListener_HandleOracleRequestV1_ThresholdDecryptedSecrets(t *te uni.eaClient.On("FetchEncryptedSecrets", mock.Anything, mock.Anything, RequestIDStr, mock.Anything, mock.Anything).Return(EncryptedSecrets, nil, nil) uni.decryptor.On("Decrypt", mock.Anything, decryptionPlugin.CiphertextId(RequestID[:]), EncryptedSecrets).Return(DecryptedSecrets, nil) uni.eaClient.On("RunComputation", mock.Anything, RequestIDStr, mock.Anything, SubscriptionOwner.Hex(), SubscriptionID, mock.Anything, mock.Anything, mock.Anything).Return(ResultBytes, nil, nil, nil) - uni.pluginORM.On("SetResult", RequestID, ResultBytes, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + uni.pluginORM.On("SetResult", mock.Anything, RequestID, ResultBytes, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { close(doneCh) }).Return(nil) @@ -333,7 +333,7 @@ func TestFunctionsListener_HandleOracleRequestV1_CBORTooBig(t *testing.T) { uni.logPollerWrapper.On("LatestEvents", mock.Anything).Return([]types.OracleRequest{request}, nil, nil).Once() uni.logPollerWrapper.On("LatestEvents", mock.Anything).Return(nil, nil, nil) uni.pluginORM.On("CreateRequest", mock.Anything, mock.Anything).Return(nil) - uni.pluginORM.On("SetError", RequestID, functions_service.USER_ERROR, []byte("request too big (max 10 bytes)"), mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + uni.pluginORM.On("SetError", mock.Anything, RequestID, functions_service.USER_ERROR, []byte("request too big (max 10 bytes)"), mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { close(doneCh) }).Return(nil) @@ -361,7 +361,7 @@ func TestFunctionsListener_ReportSourceCodeDomains(t *testing.T) { uni.pluginORM.On("CreateRequest", mock.Anything, mock.Anything).Return(nil) uni.bridgeAccessor.On("NewExternalAdapterClient", mock.Anything).Return(uni.eaClient, nil) uni.eaClient.On("RunComputation", mock.Anything, RequestIDStr, mock.Anything, SubscriptionOwner.Hex(), SubscriptionID, mock.Anything, mock.Anything, mock.Anything).Return(ResultBytes, nil, Domains, nil) - uni.pluginORM.On("SetResult", RequestID, ResultBytes, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + uni.pluginORM.On("SetResult", mock.Anything, RequestID, ResultBytes, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { close(doneCh) }).Return(nil) var sentMessage []byte @@ -388,7 +388,7 @@ func TestFunctionsListener_PruneRequests(t *testing.T) { uni := NewFunctionsListenerUniverse(t, 0, 1) doneCh := make(chan bool) uni.logPollerWrapper.On("LatestEvents", mock.Anything).Return(nil, nil, nil) - uni.pluginORM.On("PruneOldestRequests", functions_service.DefaultPruneMaxStoredRequests, functions_service.DefaultPruneBatchSize, mock.Anything).Return(uint32(0), uint32(0), nil).Run(func(args mock.Arguments) { + uni.pluginORM.On("PruneOldestRequests", mock.Anything, functions_service.DefaultPruneMaxStoredRequests, functions_service.DefaultPruneBatchSize, mock.Anything).Return(uint32(0), uint32(0), nil).Run(func(args mock.Arguments) { doneCh <- true }) @@ -403,7 +403,7 @@ func TestFunctionsListener_TimeoutRequests(t *testing.T) { uni := NewFunctionsListenerUniverse(t, 1, 0) doneCh := make(chan bool) uni.logPollerWrapper.On("LatestEvents", mock.Anything).Return(nil, nil, nil) - uni.pluginORM.On("TimeoutExpiredResults", mock.Anything, uint32(1), mock.Anything).Return([]functions_service.RequestID{}, nil).Run(func(args mock.Arguments) { + uni.pluginORM.On("TimeoutExpiredResults", mock.Anything, mock.Anything, uint32(1), mock.Anything).Return([]functions_service.RequestID{}, nil).Run(func(args mock.Arguments) { doneCh <- true }) @@ -423,9 +423,7 @@ func TestFunctionsListener_ORMDoesNotFreezeHandlersForever(t *testing.T) { uni.logPollerWrapper.On("LatestEvents", mock.Anything).Return([]types.OracleRequest{request}, nil, nil).Once() uni.logPollerWrapper.On("LatestEvents", mock.Anything).Return(nil, nil, nil) uni.pluginORM.On("CreateRequest", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { - var queryerWrapper pg.Q - args.Get(1).(pg.QOpt)(&queryerWrapper) - <-queryerWrapper.ParentCtx.Done() + <-args.Get(0).(context.Context).Done() ormCallExited.Done() }).Return(errors.New("timeout")) diff --git a/core/services/functions/mocks/orm.go b/core/services/functions/mocks/orm.go index 90055fe6286..ff72916171b 100644 --- a/core/services/functions/mocks/orm.go +++ b/core/services/functions/mocks/orm.go @@ -3,11 +3,11 @@ package mocks import ( + context "context" + functions "github.com/smartcontractkit/chainlink/v2/core/services/functions" mock "github.com/stretchr/testify/mock" - pg "github.com/smartcontractkit/chainlink/v2/core/services/pg" - time "time" ) @@ -16,24 +16,17 @@ type ORM struct { mock.Mock } -// CreateRequest provides a mock function with given fields: request, qopts -func (_m *ORM) CreateRequest(request *functions.Request, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, request) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// CreateRequest provides a mock function with given fields: ctx, request +func (_m *ORM) CreateRequest(ctx context.Context, request *functions.Request) error { + ret := _m.Called(ctx, request) if len(ret) == 0 { panic("no return value specified for CreateRequest") } var r0 error - if rf, ok := ret.Get(0).(func(*functions.Request, ...pg.QOpt) error); ok { - r0 = rf(request, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, *functions.Request) error); ok { + r0 = rf(ctx, request) } else { r0 = ret.Error(0) } @@ -41,16 +34,9 @@ func (_m *ORM) CreateRequest(request *functions.Request, qopts ...pg.QOpt) error return r0 } -// FindById provides a mock function with given fields: requestID, qopts -func (_m *ORM) FindById(requestID functions.RequestID, qopts ...pg.QOpt) (*functions.Request, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, requestID) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// FindById provides a mock function with given fields: ctx, requestID +func (_m *ORM) FindById(ctx context.Context, requestID functions.RequestID) (*functions.Request, error) { + ret := _m.Called(ctx, requestID) if len(ret) == 0 { panic("no return value specified for FindById") @@ -58,19 +44,19 @@ func (_m *ORM) FindById(requestID functions.RequestID, qopts ...pg.QOpt) (*funct var r0 *functions.Request var r1 error - if rf, ok := ret.Get(0).(func(functions.RequestID, ...pg.QOpt) (*functions.Request, error)); ok { - return rf(requestID, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, functions.RequestID) (*functions.Request, error)); ok { + return rf(ctx, requestID) } - if rf, ok := ret.Get(0).(func(functions.RequestID, ...pg.QOpt) *functions.Request); ok { - r0 = rf(requestID, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, functions.RequestID) *functions.Request); ok { + r0 = rf(ctx, requestID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*functions.Request) } } - if rf, ok := ret.Get(1).(func(functions.RequestID, ...pg.QOpt) error); ok { - r1 = rf(requestID, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, functions.RequestID) error); ok { + r1 = rf(ctx, requestID) } else { r1 = ret.Error(1) } @@ -78,16 +64,9 @@ func (_m *ORM) FindById(requestID functions.RequestID, qopts ...pg.QOpt) (*funct return r0, r1 } -// FindOldestEntriesByState provides a mock function with given fields: state, limit, qopts -func (_m *ORM) FindOldestEntriesByState(state functions.RequestState, limit uint32, qopts ...pg.QOpt) ([]functions.Request, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, state, limit) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// FindOldestEntriesByState provides a mock function with given fields: ctx, state, limit +func (_m *ORM) FindOldestEntriesByState(ctx context.Context, state functions.RequestState, limit uint32) ([]functions.Request, error) { + ret := _m.Called(ctx, state, limit) if len(ret) == 0 { panic("no return value specified for FindOldestEntriesByState") @@ -95,19 +74,19 @@ func (_m *ORM) FindOldestEntriesByState(state functions.RequestState, limit uint var r0 []functions.Request var r1 error - if rf, ok := ret.Get(0).(func(functions.RequestState, uint32, ...pg.QOpt) ([]functions.Request, error)); ok { - return rf(state, limit, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, functions.RequestState, uint32) ([]functions.Request, error)); ok { + return rf(ctx, state, limit) } - if rf, ok := ret.Get(0).(func(functions.RequestState, uint32, ...pg.QOpt) []functions.Request); ok { - r0 = rf(state, limit, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, functions.RequestState, uint32) []functions.Request); ok { + r0 = rf(ctx, state, limit) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]functions.Request) } } - if rf, ok := ret.Get(1).(func(functions.RequestState, uint32, ...pg.QOpt) error); ok { - r1 = rf(state, limit, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, functions.RequestState, uint32) error); ok { + r1 = rf(ctx, state, limit) } else { r1 = ret.Error(1) } @@ -115,16 +94,9 @@ func (_m *ORM) FindOldestEntriesByState(state functions.RequestState, limit uint return r0, r1 } -// PruneOldestRequests provides a mock function with given fields: maxRequestsInDB, batchSize, qopts -func (_m *ORM) PruneOldestRequests(maxRequestsInDB uint32, batchSize uint32, qopts ...pg.QOpt) (uint32, uint32, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, maxRequestsInDB, batchSize) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// PruneOldestRequests provides a mock function with given fields: ctx, maxRequestsInDB, batchSize +func (_m *ORM) PruneOldestRequests(ctx context.Context, maxRequestsInDB uint32, batchSize uint32) (uint32, uint32, error) { + ret := _m.Called(ctx, maxRequestsInDB, batchSize) if len(ret) == 0 { panic("no return value specified for PruneOldestRequests") @@ -133,23 +105,23 @@ func (_m *ORM) PruneOldestRequests(maxRequestsInDB uint32, batchSize uint32, qop var r0 uint32 var r1 uint32 var r2 error - if rf, ok := ret.Get(0).(func(uint32, uint32, ...pg.QOpt) (uint32, uint32, error)); ok { - return rf(maxRequestsInDB, batchSize, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, uint32, uint32) (uint32, uint32, error)); ok { + return rf(ctx, maxRequestsInDB, batchSize) } - if rf, ok := ret.Get(0).(func(uint32, uint32, ...pg.QOpt) uint32); ok { - r0 = rf(maxRequestsInDB, batchSize, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, uint32, uint32) uint32); ok { + r0 = rf(ctx, maxRequestsInDB, batchSize) } else { r0 = ret.Get(0).(uint32) } - if rf, ok := ret.Get(1).(func(uint32, uint32, ...pg.QOpt) uint32); ok { - r1 = rf(maxRequestsInDB, batchSize, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, uint32, uint32) uint32); ok { + r1 = rf(ctx, maxRequestsInDB, batchSize) } else { r1 = ret.Get(1).(uint32) } - if rf, ok := ret.Get(2).(func(uint32, uint32, ...pg.QOpt) error); ok { - r2 = rf(maxRequestsInDB, batchSize, qopts...) + if rf, ok := ret.Get(2).(func(context.Context, uint32, uint32) error); ok { + r2 = rf(ctx, maxRequestsInDB, batchSize) } else { r2 = ret.Error(2) } @@ -157,24 +129,17 @@ func (_m *ORM) PruneOldestRequests(maxRequestsInDB uint32, batchSize uint32, qop return r0, r1, r2 } -// SetConfirmed provides a mock function with given fields: requestID, qopts -func (_m *ORM) SetConfirmed(requestID functions.RequestID, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, requestID) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// SetConfirmed provides a mock function with given fields: ctx, requestID +func (_m *ORM) SetConfirmed(ctx context.Context, requestID functions.RequestID) error { + ret := _m.Called(ctx, requestID) if len(ret) == 0 { panic("no return value specified for SetConfirmed") } var r0 error - if rf, ok := ret.Get(0).(func(functions.RequestID, ...pg.QOpt) error); ok { - r0 = rf(requestID, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, functions.RequestID) error); ok { + r0 = rf(ctx, requestID) } else { r0 = ret.Error(0) } @@ -182,24 +147,17 @@ func (_m *ORM) SetConfirmed(requestID functions.RequestID, qopts ...pg.QOpt) err return r0 } -// SetError provides a mock function with given fields: requestID, errorType, computationError, readyAt, readyForProcessing, qopts -func (_m *ORM) SetError(requestID functions.RequestID, errorType functions.ErrType, computationError []byte, readyAt time.Time, readyForProcessing bool, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, requestID, errorType, computationError, readyAt, readyForProcessing) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// SetError provides a mock function with given fields: ctx, requestID, errorType, computationError, readyAt, readyForProcessing +func (_m *ORM) SetError(ctx context.Context, requestID functions.RequestID, errorType functions.ErrType, computationError []byte, readyAt time.Time, readyForProcessing bool) error { + ret := _m.Called(ctx, requestID, errorType, computationError, readyAt, readyForProcessing) if len(ret) == 0 { panic("no return value specified for SetError") } var r0 error - if rf, ok := ret.Get(0).(func(functions.RequestID, functions.ErrType, []byte, time.Time, bool, ...pg.QOpt) error); ok { - r0 = rf(requestID, errorType, computationError, readyAt, readyForProcessing, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, functions.RequestID, functions.ErrType, []byte, time.Time, bool) error); ok { + r0 = rf(ctx, requestID, errorType, computationError, readyAt, readyForProcessing) } else { r0 = ret.Error(0) } @@ -207,24 +165,17 @@ func (_m *ORM) SetError(requestID functions.RequestID, errorType functions.ErrTy return r0 } -// SetFinalized provides a mock function with given fields: requestID, reportedResult, reportedError, qopts -func (_m *ORM) SetFinalized(requestID functions.RequestID, reportedResult []byte, reportedError []byte, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, requestID, reportedResult, reportedError) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// SetFinalized provides a mock function with given fields: ctx, requestID, reportedResult, reportedError +func (_m *ORM) SetFinalized(ctx context.Context, requestID functions.RequestID, reportedResult []byte, reportedError []byte) error { + ret := _m.Called(ctx, requestID, reportedResult, reportedError) if len(ret) == 0 { panic("no return value specified for SetFinalized") } var r0 error - if rf, ok := ret.Get(0).(func(functions.RequestID, []byte, []byte, ...pg.QOpt) error); ok { - r0 = rf(requestID, reportedResult, reportedError, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, functions.RequestID, []byte, []byte) error); ok { + r0 = rf(ctx, requestID, reportedResult, reportedError) } else { r0 = ret.Error(0) } @@ -232,24 +183,17 @@ func (_m *ORM) SetFinalized(requestID functions.RequestID, reportedResult []byte return r0 } -// SetResult provides a mock function with given fields: requestID, computationResult, readyAt, qopts -func (_m *ORM) SetResult(requestID functions.RequestID, computationResult []byte, readyAt time.Time, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, requestID, computationResult, readyAt) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// SetResult provides a mock function with given fields: ctx, requestID, computationResult, readyAt +func (_m *ORM) SetResult(ctx context.Context, requestID functions.RequestID, computationResult []byte, readyAt time.Time) error { + ret := _m.Called(ctx, requestID, computationResult, readyAt) if len(ret) == 0 { panic("no return value specified for SetResult") } var r0 error - if rf, ok := ret.Get(0).(func(functions.RequestID, []byte, time.Time, ...pg.QOpt) error); ok { - r0 = rf(requestID, computationResult, readyAt, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, functions.RequestID, []byte, time.Time) error); ok { + r0 = rf(ctx, requestID, computationResult, readyAt) } else { r0 = ret.Error(0) } @@ -257,16 +201,9 @@ func (_m *ORM) SetResult(requestID functions.RequestID, computationResult []byte return r0 } -// TimeoutExpiredResults provides a mock function with given fields: cutoff, limit, qopts -func (_m *ORM) TimeoutExpiredResults(cutoff time.Time, limit uint32, qopts ...pg.QOpt) ([]functions.RequestID, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, cutoff, limit) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// TimeoutExpiredResults provides a mock function with given fields: ctx, cutoff, limit +func (_m *ORM) TimeoutExpiredResults(ctx context.Context, cutoff time.Time, limit uint32) ([]functions.RequestID, error) { + ret := _m.Called(ctx, cutoff, limit) if len(ret) == 0 { panic("no return value specified for TimeoutExpiredResults") @@ -274,19 +211,19 @@ func (_m *ORM) TimeoutExpiredResults(cutoff time.Time, limit uint32, qopts ...pg var r0 []functions.RequestID var r1 error - if rf, ok := ret.Get(0).(func(time.Time, uint32, ...pg.QOpt) ([]functions.RequestID, error)); ok { - return rf(cutoff, limit, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, time.Time, uint32) ([]functions.RequestID, error)); ok { + return rf(ctx, cutoff, limit) } - if rf, ok := ret.Get(0).(func(time.Time, uint32, ...pg.QOpt) []functions.RequestID); ok { - r0 = rf(cutoff, limit, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, time.Time, uint32) []functions.RequestID); ok { + r0 = rf(ctx, cutoff, limit) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]functions.RequestID) } } - if rf, ok := ret.Get(1).(func(time.Time, uint32, ...pg.QOpt) error); ok { - r1 = rf(cutoff, limit, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, time.Time, uint32) error); ok { + r1 = rf(ctx, cutoff, limit) } else { r1 = ret.Error(1) } diff --git a/core/services/functions/orm.go b/core/services/functions/orm.go index 7838c700858..f45effa9354 100644 --- a/core/services/functions/orm.go +++ b/core/services/functions/orm.go @@ -1,38 +1,37 @@ package functions import ( + "context" "fmt" "time" "github.com/ethereum/go-ethereum/common" - "github.com/pkg/errors" - "github.com/jmoiron/sqlx" + "github.com/pkg/errors" - "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" ) //go:generate mockery --quiet --name ORM --output ./mocks/ --case=underscore type ORM interface { - CreateRequest(request *Request, qopts ...pg.QOpt) error + CreateRequest(ctx context.Context, request *Request) error - SetResult(requestID RequestID, computationResult []byte, readyAt time.Time, qopts ...pg.QOpt) error - SetError(requestID RequestID, errorType ErrType, computationError []byte, readyAt time.Time, readyForProcessing bool, qopts ...pg.QOpt) error - SetFinalized(requestID RequestID, reportedResult []byte, reportedError []byte, qopts ...pg.QOpt) error - SetConfirmed(requestID RequestID, qopts ...pg.QOpt) error + SetResult(ctx context.Context, requestID RequestID, computationResult []byte, readyAt time.Time) error + SetError(ctx context.Context, requestID RequestID, errorType ErrType, computationError []byte, readyAt time.Time, readyForProcessing bool) error + SetFinalized(ctx context.Context, requestID RequestID, reportedResult []byte, reportedError []byte) error + SetConfirmed(ctx context.Context, requestID RequestID) error - TimeoutExpiredResults(cutoff time.Time, limit uint32, qopts ...pg.QOpt) ([]RequestID, error) + TimeoutExpiredResults(ctx context.Context, cutoff time.Time, limit uint32) ([]RequestID, error) - FindOldestEntriesByState(state RequestState, limit uint32, qopts ...pg.QOpt) ([]Request, error) - FindById(requestID RequestID, qopts ...pg.QOpt) (*Request, error) + FindOldestEntriesByState(ctx context.Context, state RequestState, limit uint32) ([]Request, error) + FindById(ctx context.Context, requestID RequestID) (*Request, error) - PruneOldestRequests(maxRequestsInDB uint32, batchSize uint32, qopts ...pg.QOpt) (total uint32, pruned uint32, err error) + PruneOldestRequests(ctx context.Context, maxRequestsInDB uint32, batchSize uint32) (total uint32, pruned uint32, err error) } type orm struct { - q pg.Q + ds sqlutil.DataSource contractAddress common.Address } @@ -49,19 +48,20 @@ const ( "callback_gas_limit, coordinator_contract_address, onchain_metadata, processing_metadata" ) -func NewORM(db *sqlx.DB, lggr logger.Logger, cfg pg.QConfig, contractAddress common.Address) ORM { +func NewORM(ds sqlutil.DataSource, contractAddress common.Address) ORM { return &orm{ - q: pg.NewQ(db, lggr, cfg), + ds: ds, contractAddress: contractAddress, } } -func (o *orm) CreateRequest(request *Request, qopts ...pg.QOpt) error { +func (o *orm) CreateRequest(ctx context.Context, request *Request) error { stmt := fmt.Sprintf(` INSERT INTO %s (request_id, contract_address, received_at, request_tx_hash, state, flags, aggregation_method, callback_gas_limit, coordinator_contract_address, onchain_metadata) VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10) ON CONFLICT (request_id) DO NOTHING; `, tableName) - result, err := o.q.WithOpts(qopts...).Exec( + result, err := o.ds.ExecContext( + ctx, stmt, request.RequestID, o.contractAddress, @@ -86,11 +86,11 @@ func (o *orm) CreateRequest(request *Request, qopts ...pg.QOpt) error { return nil } -func (o *orm) setWithStateTransitionCheck(requestID RequestID, newState RequestState, setter func(pg.Queryer) error, qopts ...pg.QOpt) error { - err := o.q.WithOpts(qopts...).Transaction(func(tx pg.Queryer) error { +func (o *orm) setWithStateTransitionCheck(ctx context.Context, requestID RequestID, newState RequestState, setter func(sqlutil.DataSource) error) error { + err := sqlutil.TransactDataSource(ctx, o.ds, nil, func(tx sqlutil.DataSource) error { prevState := defaultInitialState stmt := fmt.Sprintf(`SELECT state FROM %s WHERE request_id=$1 AND contract_address=$2;`, tableName) - if err2 := tx.Get(&prevState, stmt, requestID, o.contractAddress); err2 != nil { + if err2 := tx.GetContext(ctx, &prevState, stmt, requestID, o.contractAddress); err2 != nil { return err2 } if err2 := CheckStateTransition(prevState, newState); err2 != nil { @@ -102,64 +102,64 @@ func (o *orm) setWithStateTransitionCheck(requestID RequestID, newState RequestS return err } -func (o *orm) SetResult(requestID RequestID, computationResult []byte, readyAt time.Time, qopts ...pg.QOpt) error { +func (o *orm) SetResult(ctx context.Context, requestID RequestID, computationResult []byte, readyAt time.Time) error { newState := RESULT_READY - err := o.setWithStateTransitionCheck(requestID, newState, func(tx pg.Queryer) error { + err := o.setWithStateTransitionCheck(ctx, requestID, newState, func(tx sqlutil.DataSource) error { stmt := fmt.Sprintf(` UPDATE %s SET result=$3, result_ready_at=$4, state=$5 WHERE request_id=$1 AND contract_address=$2; `, tableName) - _, err2 := tx.Exec(stmt, requestID, o.contractAddress, computationResult, readyAt, newState) + _, err2 := tx.ExecContext(ctx, stmt, requestID, o.contractAddress, computationResult, readyAt, newState) return err2 - }, qopts...) + }) return err } -func (o *orm) SetError(requestID RequestID, errorType ErrType, computationError []byte, readyAt time.Time, readyForProcessing bool, qopts ...pg.QOpt) error { +func (o *orm) SetError(ctx context.Context, requestID RequestID, errorType ErrType, computationError []byte, readyAt time.Time, readyForProcessing bool) error { var newState RequestState if readyForProcessing { newState = RESULT_READY } else { newState = IN_PROGRESS } - err := o.setWithStateTransitionCheck(requestID, newState, func(tx pg.Queryer) error { + err := o.setWithStateTransitionCheck(ctx, requestID, newState, func(tx sqlutil.DataSource) error { stmt := fmt.Sprintf(` UPDATE %s SET error=$3, error_type=$4, result_ready_at=$5, state=$6 WHERE request_id=$1 AND contract_address=$2; `, tableName) - _, err2 := tx.Exec(stmt, requestID, o.contractAddress, computationError, errorType, readyAt, newState) + _, err2 := tx.ExecContext(ctx, stmt, requestID, o.contractAddress, computationError, errorType, readyAt, newState) return err2 - }, qopts...) + }) return err } -func (o *orm) SetFinalized(requestID RequestID, reportedResult []byte, reportedError []byte, qopts ...pg.QOpt) error { +func (o *orm) SetFinalized(ctx context.Context, requestID RequestID, reportedResult []byte, reportedError []byte) error { newState := FINALIZED - err := o.setWithStateTransitionCheck(requestID, newState, func(tx pg.Queryer) error { + err := o.setWithStateTransitionCheck(ctx, requestID, newState, func(tx sqlutil.DataSource) error { stmt := fmt.Sprintf(` UPDATE %s SET transmitted_result=$3, transmitted_error=$4, state=$5 WHERE request_id=$1 AND contract_address=$2; `, tableName) - _, err2 := tx.Exec(stmt, requestID, o.contractAddress, reportedResult, reportedError, newState) + _, err2 := tx.ExecContext(ctx, stmt, requestID, o.contractAddress, reportedResult, reportedError, newState) return err2 - }, qopts...) + }) return err } -func (o *orm) SetConfirmed(requestID RequestID, qopts ...pg.QOpt) error { +func (o *orm) SetConfirmed(ctx context.Context, requestID RequestID) error { newState := CONFIRMED - err := o.setWithStateTransitionCheck(requestID, newState, func(tx pg.Queryer) error { + err := o.setWithStateTransitionCheck(ctx, requestID, newState, func(tx sqlutil.DataSource) error { stmt := fmt.Sprintf(`UPDATE %s SET state=$3 WHERE request_id=$1 AND contract_address=$2;`, tableName) - _, err2 := tx.Exec(stmt, requestID, o.contractAddress, newState) + _, err2 := tx.ExecContext(ctx, stmt, requestID, o.contractAddress, newState) return err2 - }, qopts...) + }) return err } -func (o *orm) TimeoutExpiredResults(cutoff time.Time, limit uint32, qopts ...pg.QOpt) ([]RequestID, error) { +func (o *orm) TimeoutExpiredResults(ctx context.Context, cutoff time.Time, limit uint32) ([]RequestID, error) { var ids []RequestID allowedPrevStates := []RequestState{IN_PROGRESS, RESULT_READY, FINALIZED} nextState := TIMED_OUT @@ -169,14 +169,14 @@ func (o *orm) TimeoutExpiredResults(cutoff time.Time, limit uint32, qopts ...pg. return ids, err } } - err := o.q.WithOpts(qopts...).Transaction(func(tx pg.Queryer) error { + err := sqlutil.TransactDataSource(ctx, o.ds, nil, func(tx sqlutil.DataSource) error { selectStmt := fmt.Sprintf(` SELECT request_id FROM %s WHERE (state=$1 OR state=$2 OR state=$3) AND contract_address=$4 AND received_at < ($5) ORDER BY received_at LIMIT $6;`, tableName) - if err2 := tx.Select(&ids, selectStmt, allowedPrevStates[0], allowedPrevStates[1], allowedPrevStates[2], o.contractAddress, cutoff, limit); err2 != nil { + if err2 := tx.SelectContext(ctx, &ids, selectStmt, allowedPrevStates[0], allowedPrevStates[1], allowedPrevStates[2], o.contractAddress, cutoff, limit); err2 != nil { return err2 } if len(ids) == 0 { @@ -200,7 +200,7 @@ func (o *orm) TimeoutExpiredResults(cutoff time.Time, limit uint32, qopts ...pg. return err2 } updateStmt = tx.Rebind(updateStmt) - if _, err2 := tx.Exec(updateStmt, args...); err2 != nil { + if _, err2 := tx.ExecContext(ctx, updateStmt, args...); err2 != nil { return err2 } return nil @@ -209,28 +209,28 @@ func (o *orm) TimeoutExpiredResults(cutoff time.Time, limit uint32, qopts ...pg. return ids, err } -func (o *orm) FindOldestEntriesByState(state RequestState, limit uint32, qopts ...pg.QOpt) ([]Request, error) { +func (o *orm) FindOldestEntriesByState(ctx context.Context, state RequestState, limit uint32) ([]Request, error) { var requests []Request stmt := fmt.Sprintf(`SELECT %s FROM %s WHERE state=$1 AND contract_address=$2 ORDER BY received_at LIMIT $3;`, requestFields, tableName) - if err := o.q.WithOpts(qopts...).Select(&requests, stmt, state, o.contractAddress, limit); err != nil { + if err := o.ds.SelectContext(ctx, &requests, stmt, state, o.contractAddress, limit); err != nil { return nil, err } return requests, nil } -func (o *orm) FindById(requestID RequestID, qopts ...pg.QOpt) (*Request, error) { +func (o *orm) FindById(ctx context.Context, requestID RequestID) (*Request, error) { var request Request stmt := fmt.Sprintf(`SELECT %s FROM %s WHERE request_id=$1 AND contract_address=$2;`, requestFields, tableName) - if err := o.q.WithOpts(qopts...).Get(&request, stmt, requestID, o.contractAddress); err != nil { + if err := o.ds.GetContext(ctx, &request, stmt, requestID, o.contractAddress); err != nil { return nil, err } return &request, nil } -func (o *orm) PruneOldestRequests(maxStoredRequests uint32, batchSize uint32, qopts ...pg.QOpt) (total uint32, pruned uint32, err error) { - err = o.q.WithOpts(qopts...).Transaction(func(tx pg.Queryer) error { +func (o *orm) PruneOldestRequests(ctx context.Context, maxStoredRequests uint32, batchSize uint32) (total uint32, pruned uint32, err error) { + err = sqlutil.TransactDataSource(ctx, o.ds, nil, func(tx sqlutil.DataSource) error { stmt := fmt.Sprintf(`SELECT COUNT(*) FROM %s WHERE contract_address=$1`, tableName) - if err2 := tx.Get(&total, stmt, o.contractAddress); err2 != nil { + if err2 := tx.GetContext(ctx, &total, stmt, o.contractAddress); err2 != nil { return errors.Wrap(err, "failed to get request count") } @@ -246,7 +246,7 @@ func (o *orm) PruneOldestRequests(maxStoredRequests uint32, batchSize uint32, qo with := fmt.Sprintf(`WITH ids AS (SELECT request_id FROM %s WHERE contract_address = $1 ORDER BY received_at LIMIT $2)`, tableName) deleteStmt := fmt.Sprintf(`%s DELETE FROM %s WHERE contract_address = $1 AND request_id IN (SELECT request_id FROM ids);`, with, tableName) - res, err2 := tx.Exec(deleteStmt, o.contractAddress, pruneLimit) + res, err2 := tx.ExecContext(ctx, deleteStmt, o.contractAddress, pruneLimit) if err2 != nil { return err2 } diff --git a/core/services/functions/orm_test.go b/core/services/functions/orm_test.go index ca92aafcb0e..37b3a28256f 100644 --- a/core/services/functions/orm_test.go +++ b/core/services/functions/orm_test.go @@ -11,7 +11,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/pgtest" - "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/functions" ) @@ -28,9 +27,8 @@ func setupORM(t *testing.T) functions.ORM { var ( db = pgtest.NewSqlxDB(t) - lggr = logger.TestLogger(t) contract = testutils.NewAddress() - orm = functions.NewORM(db, lggr, pgtest.NewQConfig(true), contract) + orm = functions.NewORM(db, contract) ) return orm @@ -47,6 +45,7 @@ func createRequest(t *testing.T, orm functions.ORM) (functions.RequestID, common } func createRequestWithTimestamp(t *testing.T, orm functions.ORM, ts time.Time) (functions.RequestID, common.Hash) { + ctx := testutils.Context(t) id := newRequestID() txHash := utils.RandomHash() newReq := &functions.Request{ @@ -59,19 +58,20 @@ func createRequestWithTimestamp(t *testing.T, orm functions.ORM, ts time.Time) ( CoordinatorContractAddress: &defaultCoordinatorContract, OnchainMetadata: defaultMetadata, } - err := orm.CreateRequest(newReq) + err := orm.CreateRequest(ctx, newReq) require.NoError(t, err) return id, txHash } func TestORM_CreateRequestsAndFindByID(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := setupORM(t) id1, txHash1, ts1 := createRequest(t, orm) id2, txHash2, ts2 := createRequest(t, orm) - req1, err := orm.FindById(id1) + req1, err := orm.FindById(ctx, id1) require.NoError(t, err) require.Equal(t, id1, req1.RequestID) require.Equal(t, &txHash1, req1.RequestTxHash) @@ -83,7 +83,7 @@ func TestORM_CreateRequestsAndFindByID(t *testing.T) { require.Equal(t, defaultCoordinatorContract, *req1.CoordinatorContractAddress) require.Equal(t, defaultMetadata, req1.OnchainMetadata) - req2, err := orm.FindById(id2) + req2, err := orm.FindById(ctx, id2) require.NoError(t, err) require.Equal(t, id2, req2.RequestID) require.Equal(t, &txHash2, req2.RequestTxHash) @@ -91,14 +91,14 @@ func TestORM_CreateRequestsAndFindByID(t *testing.T) { require.Equal(t, functions.IN_PROGRESS, req2.State) t.Run("missing ID", func(t *testing.T) { - req, err := orm.FindById(newRequestID()) + req, err := orm.FindById(testutils.Context(t), newRequestID()) require.Error(t, err) require.Nil(t, req) }) t.Run("duplicated", func(t *testing.T) { newReq := &functions.Request{RequestID: id1, RequestTxHash: &txHash1, ReceivedAt: ts1} - err := orm.CreateRequest(newReq) + err := orm.CreateRequest(testutils.Context(t), newReq) require.Error(t, err) require.True(t, errors.Is(err, functions.ErrDuplicateRequestID)) }) @@ -106,15 +106,16 @@ func TestORM_CreateRequestsAndFindByID(t *testing.T) { func TestORM_SetResult(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := setupORM(t) id, _, ts := createRequest(t, orm) rdts := time.Now().Round(time.Second) - err := orm.SetResult(id, []byte("result"), rdts) + err := orm.SetResult(ctx, id, []byte("result"), rdts) require.NoError(t, err) - req, err := orm.FindById(id) + req, err := orm.FindById(ctx, id) require.NoError(t, err) require.Equal(t, id, req.RequestID) require.Equal(t, ts, req.ReceivedAt) @@ -126,15 +127,16 @@ func TestORM_SetResult(t *testing.T) { func TestORM_SetError(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := setupORM(t) id, _, ts := createRequest(t, orm) rdts := time.Now().Round(time.Second) - err := orm.SetError(id, functions.USER_ERROR, []byte("error"), rdts, true) + err := orm.SetError(ctx, id, functions.USER_ERROR, []byte("error"), rdts, true) require.NoError(t, err) - req, err := orm.FindById(id) + req, err := orm.FindById(ctx, id) require.NoError(t, err) require.Equal(t, id, req.RequestID) require.Equal(t, ts, req.ReceivedAt) @@ -148,15 +150,16 @@ func TestORM_SetError(t *testing.T) { func TestORM_SetError_Internal(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := setupORM(t) id, _, ts := createRequest(t, orm) rdts := time.Now().Round(time.Second) - err := orm.SetError(id, functions.INTERNAL_ERROR, []byte("error"), rdts, false) + err := orm.SetError(ctx, id, functions.INTERNAL_ERROR, []byte("error"), rdts, false) require.NoError(t, err) - req, err := orm.FindById(id) + req, err := orm.FindById(ctx, id) require.NoError(t, err) require.Equal(t, id, req.RequestID) require.Equal(t, ts, req.ReceivedAt) @@ -167,14 +170,15 @@ func TestORM_SetError_Internal(t *testing.T) { func TestORM_SetFinalized(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := setupORM(t) id, _, _ := createRequest(t, orm) - err := orm.SetFinalized(id, []byte("result"), []byte("error")) + err := orm.SetFinalized(ctx, id, []byte("result"), []byte("error")) require.NoError(t, err) - req, err := orm.FindById(id) + req, err := orm.FindById(ctx, id) require.NoError(t, err) require.Equal(t, []byte("result"), req.TransmittedResult) require.Equal(t, []byte("error"), req.TransmittedError) @@ -183,49 +187,51 @@ func TestORM_SetFinalized(t *testing.T) { func TestORM_SetConfirmed(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := setupORM(t) id, _, _ := createRequest(t, orm) - err := orm.SetConfirmed(id) + err := orm.SetConfirmed(ctx, id) require.NoError(t, err) - req, err := orm.FindById(id) + req, err := orm.FindById(ctx, id) require.NoError(t, err) require.Equal(t, functions.CONFIRMED, req.State) } func TestORM_StateTransitions(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := setupORM(t) now := time.Now() id, _ := createRequestWithTimestamp(t, orm, now) - req, err := orm.FindById(id) + req, err := orm.FindById(ctx, id) require.NoError(t, err) require.Equal(t, functions.IN_PROGRESS, req.State) - err = orm.SetResult(id, []byte{}, now) + err = orm.SetResult(ctx, id, []byte{}, now) require.NoError(t, err) - req, err = orm.FindById(id) + req, err = orm.FindById(ctx, id) require.NoError(t, err) require.Equal(t, functions.RESULT_READY, req.State) - _, err = orm.TimeoutExpiredResults(now.Add(time.Minute), 1) + _, err = orm.TimeoutExpiredResults(ctx, now.Add(time.Minute), 1) require.NoError(t, err) - req, err = orm.FindById(id) + req, err = orm.FindById(ctx, id) require.NoError(t, err) require.Equal(t, functions.TIMED_OUT, req.State) - err = orm.SetFinalized(id, nil, nil) + err = orm.SetFinalized(ctx, id, nil, nil) require.Error(t, err) - req, err = orm.FindById(id) + req, err = orm.FindById(ctx, id) require.NoError(t, err) require.Equal(t, functions.TIMED_OUT, req.State) - err = orm.SetConfirmed(id) + err = orm.SetConfirmed(ctx, id) require.NoError(t, err) - req, err = orm.FindById(id) + req, err = orm.FindById(ctx, id) require.NoError(t, err) require.Equal(t, functions.CONFIRMED, req.State) } @@ -240,7 +246,8 @@ func TestORM_FindOldestEntriesByState(t *testing.T) { id1, _ := createRequestWithTimestamp(t, orm, now.Add(1*time.Minute)) t.Run("with limit", func(t *testing.T) { - result, err := orm.FindOldestEntriesByState(functions.IN_PROGRESS, 2) + ctx := testutils.Context(t) + result, err := orm.FindOldestEntriesByState(ctx, functions.IN_PROGRESS, 2) require.NoError(t, err) require.Equal(t, 2, len(result), "incorrect results length") require.Equal(t, id1, result[0].RequestID, "incorrect results order") @@ -255,13 +262,15 @@ func TestORM_FindOldestEntriesByState(t *testing.T) { }) t.Run("with no limit", func(t *testing.T) { - result, err := orm.FindOldestEntriesByState(functions.IN_PROGRESS, 20) + ctx := testutils.Context(t) + result, err := orm.FindOldestEntriesByState(ctx, functions.IN_PROGRESS, 20) require.NoError(t, err) require.Equal(t, 3, len(result), "incorrect results length") }) t.Run("no matching entries", func(t *testing.T) { - result, err := orm.FindOldestEntriesByState(functions.RESULT_READY, 10) + ctx := testutils.Context(t) + result, err := orm.FindOldestEntriesByState(ctx, functions.RESULT_READY, 10) require.NoError(t, err) require.Equal(t, 0, len(result), "incorrect results length") }) @@ -269,6 +278,7 @@ func TestORM_FindOldestEntriesByState(t *testing.T) { func TestORM_TimeoutExpiredResults(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := setupORM(t) now := time.Now() @@ -278,26 +288,26 @@ func TestORM_TimeoutExpiredResults(t *testing.T) { ids = append(ids, id) } // can time out IN_PROGRESS, RESULT_READY or FINALIZED - err := orm.SetResult(ids[0], []byte("result"), now) + err := orm.SetResult(ctx, ids[0], []byte("result"), now) require.NoError(t, err) - err = orm.SetFinalized(ids[1], []byte("result"), []byte("")) + err = orm.SetFinalized(ctx, ids[1], []byte("result"), []byte("")) require.NoError(t, err) // can't time out CONFIRMED - err = orm.SetConfirmed(ids[2]) + err = orm.SetConfirmed(ctx, ids[2]) require.NoError(t, err) - results, err := orm.TimeoutExpiredResults(now.Add(-35*time.Minute), 1) + results, err := orm.TimeoutExpiredResults(ctx, now.Add(-35*time.Minute), 1) require.NoError(t, err) require.Equal(t, 1, len(results), "not respecting limit") require.Equal(t, ids[0], results[0], "incorrect results order") - results, err = orm.TimeoutExpiredResults(now.Add(-15*time.Minute), 10) + results, err = orm.TimeoutExpiredResults(ctx, now.Add(-15*time.Minute), 10) require.NoError(t, err) require.Equal(t, 2, len(results), "incorrect results length") require.Equal(t, ids[1], results[0], "incorrect results order") require.Equal(t, ids[3], results[1], "incorrect results order") - results, err = orm.TimeoutExpiredResults(now.Add(-15*time.Minute), 10) + results, err = orm.TimeoutExpiredResults(ctx, now.Add(-15*time.Minute), 10) require.NoError(t, err) require.Equal(t, 0, len(results), "not idempotent") @@ -309,7 +319,7 @@ func TestORM_TimeoutExpiredResults(t *testing.T) { functions.IN_PROGRESS, } for i, expectedState := range expectedFinalStates { - req, err := orm.FindById(ids[i]) + req, err := orm.FindById(ctx, ids[i]) require.NoError(t, err) require.Equal(t, req.State, expectedState, "incorrect state") } @@ -317,6 +327,7 @@ func TestORM_TimeoutExpiredResults(t *testing.T) { func TestORM_PruneOldestRequests(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := setupORM(t) now := time.Now() @@ -328,31 +339,31 @@ func TestORM_PruneOldestRequests(t *testing.T) { } // don't prune if max not hit - total, pruned, err := orm.PruneOldestRequests(6, 3) + total, pruned, err := orm.PruneOldestRequests(ctx, 6, 3) require.NoError(t, err) require.Equal(t, uint32(5), total) require.Equal(t, uint32(0), pruned) // prune up to max batch size - total, pruned, err = orm.PruneOldestRequests(1, 2) + total, pruned, err = orm.PruneOldestRequests(ctx, 1, 2) require.NoError(t, err) require.Equal(t, uint32(5), total) require.Equal(t, uint32(2), pruned) // prune all above the limit - total, pruned, err = orm.PruneOldestRequests(1, 20) + total, pruned, err = orm.PruneOldestRequests(ctx, 1, 20) require.NoError(t, err) require.Equal(t, uint32(3), total) require.Equal(t, uint32(2), pruned) // no pruning needed any more - total, pruned, err = orm.PruneOldestRequests(1, 20) + total, pruned, err = orm.PruneOldestRequests(ctx, 1, 20) require.NoError(t, err) require.Equal(t, uint32(1), total) require.Equal(t, uint32(0), pruned) // verify only the newest one is left after pruning - result, err := orm.FindOldestEntriesByState(functions.IN_PROGRESS, 20) + result, err := orm.FindOldestEntriesByState(ctx, functions.IN_PROGRESS, 20) require.NoError(t, err) require.Equal(t, 1, len(result), "incorrect results length") require.Equal(t, ids[4], result[0].RequestID, "incorrect results order") @@ -360,6 +371,7 @@ func TestORM_PruneOldestRequests(t *testing.T) { func TestORM_PruneOldestRequests_Large(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := setupORM(t) now := time.Now() @@ -369,13 +381,13 @@ func TestORM_PruneOldestRequests_Large(t *testing.T) { } // prune 900/1000 - total, pruned, err := orm.PruneOldestRequests(100, 1000) + total, pruned, err := orm.PruneOldestRequests(ctx, 100, 1000) require.NoError(t, err) require.Equal(t, uint32(1000), total) require.Equal(t, uint32(900), pruned) // verify there's 100 left - result, err := orm.FindOldestEntriesByState(functions.IN_PROGRESS, 200) + result, err := orm.FindOldestEntriesByState(ctx, functions.IN_PROGRESS, 200) require.NoError(t, err) require.Equal(t, 100, len(result), "incorrect results length") } diff --git a/core/services/gateway/handlers/functions/allowlist/allowlist.go b/core/services/gateway/handlers/functions/allowlist/allowlist.go index 020de2359c2..f0fe5c8c829 100644 --- a/core/services/gateway/handlers/functions/allowlist/allowlist.go +++ b/core/services/gateway/handlers/functions/allowlist/allowlist.go @@ -128,7 +128,7 @@ func (a *onchainAllowlist) Start(ctx context.Context) error { return nil } - a.loadStoredAllowedSenderList() + a.loadStoredAllowedSenderList(ctx) updateOnce := func() { timeoutCtx, cancel := utils.ContextFromChanWithTimeout(a.stopCh, time.Duration(a.config.UpdateTimeoutSec)*time.Second) @@ -245,12 +245,12 @@ func (a *onchainAllowlist) updateFromContractV1(ctx context.Context, blockNum *b return errors.Wrap(err, "error calling GetAllAllowedSenders") } - err = a.orm.PurgeAllowedSenders() + err = a.orm.PurgeAllowedSenders(ctx) if err != nil { a.lggr.Errorf("failed to purge allowedSenderList: %w", err) } - err = a.orm.CreateAllowedSenders(allowedSenderList) + err = a.orm.CreateAllowedSenders(ctx, allowedSenderList) if err != nil { a.lggr.Errorf("failed to update stored allowedSenderList: %w", err) } @@ -290,7 +290,7 @@ func (a *onchainAllowlist) getAllowedSendersBatched(ctx context.Context, tosCont } allowedSenderList = append(allowedSenderList, allowedSendersBatch...) - err = a.orm.CreateAllowedSenders(allowedSendersBatch) + err = a.orm.CreateAllowedSenders(ctx, allowedSendersBatch) if err != nil { a.lggr.Errorf("failed to update stored allowedSenderList: %w", err) } @@ -330,7 +330,7 @@ func (a *onchainAllowlist) syncBlockedSenders(ctx context.Context, tosContract * return errors.Wrap(err, "error calling GetAllowedSendersInRange") } - err = a.orm.DeleteAllowedSenders(blockedSendersBatch) + err = a.orm.DeleteAllowedSenders(ctx, blockedSendersBatch) if err != nil { a.lggr.Errorf("failed to delete blocked address from allowed list in storage: %w", err) } @@ -349,11 +349,11 @@ func (a *onchainAllowlist) update(addrList []common.Address) { a.lggr.Infow("allowlist updated successfully", "len", len(addrList)) } -func (a *onchainAllowlist) loadStoredAllowedSenderList() { +func (a *onchainAllowlist) loadStoredAllowedSenderList(ctx context.Context) { allowedList := make([]common.Address, 0) offset := uint(0) for { - asBatch, err := a.orm.GetAllowedSenders(offset, a.config.StoredAllowlistBatchSize) + asBatch, err := a.orm.GetAllowedSenders(ctx, offset, a.config.StoredAllowlistBatchSize) if err != nil { a.lggr.Errorf("failed to get stored allowed senders: %w", err) break diff --git a/core/services/gateway/handlers/functions/allowlist/allowlist_test.go b/core/services/gateway/handlers/functions/allowlist/allowlist_test.go index 735c0bff7dc..d4900627bdb 100644 --- a/core/services/gateway/handlers/functions/allowlist/allowlist_test.go +++ b/core/services/gateway/handlers/functions/allowlist/allowlist_test.go @@ -58,8 +58,8 @@ func TestUpdateAndCheck(t *testing.T) { } orm := amocks.NewORM(t) - orm.On("PurgeAllowedSenders").Times(1).Return(nil) - orm.On("CreateAllowedSenders", []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) + orm.On("PurgeAllowedSenders", mock.Anything).Times(1).Return(nil) + orm.On("CreateAllowedSenders", mock.Anything, []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) allowlist, err := allowlist.NewOnchainAllowlist(client, config, orm, logger.TestLogger(t)) require.NoError(t, err) @@ -99,8 +99,8 @@ func TestUpdateAndCheck(t *testing.T) { } orm := amocks.NewORM(t) - orm.On("DeleteAllowedSenders", []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) - orm.On("CreateAllowedSenders", []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) + orm.On("DeleteAllowedSenders", mock.Anything, []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) + orm.On("CreateAllowedSenders", mock.Anything, []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) allowlist, err := allowlist.NewOnchainAllowlist(client, config, orm, logger.TestLogger(t)) require.NoError(t, err) @@ -163,9 +163,9 @@ func TestUpdatePeriodically(t *testing.T) { } orm := amocks.NewORM(t) - orm.On("PurgeAllowedSenders").Times(1).Return(nil) - orm.On("GetAllowedSenders", uint(0), uint(1000)).Return([]common.Address{}, nil) - orm.On("CreateAllowedSenders", []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) + orm.On("PurgeAllowedSenders", mock.Anything).Times(1).Return(nil) + orm.On("GetAllowedSenders", mock.Anything, uint(0), uint(1000)).Return([]common.Address{}, nil) + orm.On("CreateAllowedSenders", mock.Anything, []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) allowlist, err := allowlist.NewOnchainAllowlist(client, config, orm, logger.TestLogger(t)) require.NoError(t, err) @@ -207,9 +207,9 @@ func TestUpdatePeriodically(t *testing.T) { } orm := amocks.NewORM(t) - orm.On("DeleteAllowedSenders", []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) - orm.On("GetAllowedSenders", uint(0), uint(1000)).Return([]common.Address{}, nil) - orm.On("CreateAllowedSenders", []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) + orm.On("DeleteAllowedSenders", mock.Anything, []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) + orm.On("GetAllowedSenders", mock.Anything, uint(0), uint(1000)).Return([]common.Address{}, nil) + orm.On("CreateAllowedSenders", mock.Anything, []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) allowlist, err := allowlist.NewOnchainAllowlist(client, config, orm, logger.TestLogger(t)) require.NoError(t, err) @@ -258,8 +258,8 @@ func TestUpdateFromContract(t *testing.T) { } orm := amocks.NewORM(t) - orm.On("PurgeAllowedSenders").Times(1).Return(nil) - orm.On("CreateAllowedSenders", []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) + orm.On("PurgeAllowedSenders", mock.Anything).Times(1).Return(nil) + orm.On("CreateAllowedSenders", mock.Anything, []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(1).Return(nil) allowlist, err := allowlist.NewOnchainAllowlist(client, config, orm, logger.TestLogger(t)) require.NoError(t, err) @@ -301,8 +301,8 @@ func TestUpdateFromContract(t *testing.T) { } orm := amocks.NewORM(t) - orm.On("DeleteAllowedSenders", []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(2).Return(nil) - orm.On("CreateAllowedSenders", []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(2).Return(nil) + orm.On("DeleteAllowedSenders", mock.Anything, []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(2).Return(nil) + orm.On("CreateAllowedSenders", mock.Anything, []common.Address{common.HexToAddress(addr1), common.HexToAddress(addr2)}).Times(2).Return(nil) allowlist, err := allowlist.NewOnchainAllowlist(client, config, orm, logger.TestLogger(t)) require.NoError(t, err) diff --git a/core/services/gateway/handlers/functions/allowlist/mocks/orm.go b/core/services/gateway/handlers/functions/allowlist/mocks/orm.go index daff33d8902..76121270518 100644 --- a/core/services/gateway/handlers/functions/allowlist/mocks/orm.go +++ b/core/services/gateway/handlers/functions/allowlist/mocks/orm.go @@ -3,10 +3,11 @@ package mocks import ( + context "context" + common "github.com/ethereum/go-ethereum/common" - mock "github.com/stretchr/testify/mock" - pg "github.com/smartcontractkit/chainlink/v2/core/services/pg" + mock "github.com/stretchr/testify/mock" ) // ORM is an autogenerated mock type for the ORM type @@ -14,24 +15,17 @@ type ORM struct { mock.Mock } -// CreateAllowedSenders provides a mock function with given fields: allowedSenders, qopts -func (_m *ORM) CreateAllowedSenders(allowedSenders []common.Address, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, allowedSenders) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// CreateAllowedSenders provides a mock function with given fields: ctx, allowedSenders +func (_m *ORM) CreateAllowedSenders(ctx context.Context, allowedSenders []common.Address) error { + ret := _m.Called(ctx, allowedSenders) if len(ret) == 0 { panic("no return value specified for CreateAllowedSenders") } var r0 error - if rf, ok := ret.Get(0).(func([]common.Address, ...pg.QOpt) error); ok { - r0 = rf(allowedSenders, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, []common.Address) error); ok { + r0 = rf(ctx, allowedSenders) } else { r0 = ret.Error(0) } @@ -39,24 +33,17 @@ func (_m *ORM) CreateAllowedSenders(allowedSenders []common.Address, qopts ...pg return r0 } -// DeleteAllowedSenders provides a mock function with given fields: blockedSenders, qopts -func (_m *ORM) DeleteAllowedSenders(blockedSenders []common.Address, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, blockedSenders) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// DeleteAllowedSenders provides a mock function with given fields: ctx, blockedSenders +func (_m *ORM) DeleteAllowedSenders(ctx context.Context, blockedSenders []common.Address) error { + ret := _m.Called(ctx, blockedSenders) if len(ret) == 0 { panic("no return value specified for DeleteAllowedSenders") } var r0 error - if rf, ok := ret.Get(0).(func([]common.Address, ...pg.QOpt) error); ok { - r0 = rf(blockedSenders, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, []common.Address) error); ok { + r0 = rf(ctx, blockedSenders) } else { r0 = ret.Error(0) } @@ -64,16 +51,9 @@ func (_m *ORM) DeleteAllowedSenders(blockedSenders []common.Address, qopts ...pg return r0 } -// GetAllowedSenders provides a mock function with given fields: offset, limit, qopts -func (_m *ORM) GetAllowedSenders(offset uint, limit uint, qopts ...pg.QOpt) ([]common.Address, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, offset, limit) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// GetAllowedSenders provides a mock function with given fields: ctx, offset, limit +func (_m *ORM) GetAllowedSenders(ctx context.Context, offset uint, limit uint) ([]common.Address, error) { + ret := _m.Called(ctx, offset, limit) if len(ret) == 0 { panic("no return value specified for GetAllowedSenders") @@ -81,19 +61,19 @@ func (_m *ORM) GetAllowedSenders(offset uint, limit uint, qopts ...pg.QOpt) ([]c var r0 []common.Address var r1 error - if rf, ok := ret.Get(0).(func(uint, uint, ...pg.QOpt) ([]common.Address, error)); ok { - return rf(offset, limit, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, uint, uint) ([]common.Address, error)); ok { + return rf(ctx, offset, limit) } - if rf, ok := ret.Get(0).(func(uint, uint, ...pg.QOpt) []common.Address); ok { - r0 = rf(offset, limit, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, uint, uint) []common.Address); ok { + r0 = rf(ctx, offset, limit) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]common.Address) } } - if rf, ok := ret.Get(1).(func(uint, uint, ...pg.QOpt) error); ok { - r1 = rf(offset, limit, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, uint, uint) error); ok { + r1 = rf(ctx, offset, limit) } else { r1 = ret.Error(1) } @@ -101,23 +81,17 @@ func (_m *ORM) GetAllowedSenders(offset uint, limit uint, qopts ...pg.QOpt) ([]c return r0, r1 } -// PurgeAllowedSenders provides a mock function with given fields: qopts -func (_m *ORM) PurgeAllowedSenders(qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// PurgeAllowedSenders provides a mock function with given fields: ctx +func (_m *ORM) PurgeAllowedSenders(ctx context.Context) error { + ret := _m.Called(ctx) if len(ret) == 0 { panic("no return value specified for PurgeAllowedSenders") } var r0 error - if rf, ok := ret.Get(0).(func(...pg.QOpt) error); ok { - r0 = rf(qopts...) + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(ctx) } else { r0 = ret.Error(0) } diff --git a/core/services/gateway/handlers/functions/allowlist/orm.go b/core/services/gateway/handlers/functions/allowlist/orm.go index ccacec81a43..7867c06d5d4 100644 --- a/core/services/gateway/handlers/functions/allowlist/orm.go +++ b/core/services/gateway/handlers/functions/allowlist/orm.go @@ -1,28 +1,27 @@ package allowlist import ( + "context" "fmt" "strings" "github.com/ethereum/go-ethereum/common" "github.com/pkg/errors" - "github.com/jmoiron/sqlx" - + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) //go:generate mockery --quiet --name ORM --output ./mocks/ --case=underscore type ORM interface { - GetAllowedSenders(offset, limit uint, qopts ...pg.QOpt) ([]common.Address, error) - CreateAllowedSenders(allowedSenders []common.Address, qopts ...pg.QOpt) error - DeleteAllowedSenders(blockedSenders []common.Address, qopts ...pg.QOpt) error - PurgeAllowedSenders(qopts ...pg.QOpt) error + GetAllowedSenders(ctx context.Context, offset, limit uint) ([]common.Address, error) + CreateAllowedSenders(ctx context.Context, allowedSenders []common.Address) error + DeleteAllowedSenders(ctx context.Context, blockedSenders []common.Address) error + PurgeAllowedSenders(ctx context.Context) error } type orm struct { - q pg.Q + ds sqlutil.DataSource lggr logger.Logger routerContractAddress common.Address } @@ -36,19 +35,19 @@ const ( tableName = "functions_allowlist" ) -func NewORM(db *sqlx.DB, lggr logger.Logger, cfg pg.QConfig, routerContractAddress common.Address) (ORM, error) { - if db == nil || cfg == nil || lggr == nil || routerContractAddress == (common.Address{}) { +func NewORM(ds sqlutil.DataSource, lggr logger.Logger, routerContractAddress common.Address) (ORM, error) { + if ds == nil || lggr == nil || routerContractAddress == (common.Address{}) { return nil, ErrInvalidParameters } return &orm{ - q: pg.NewQ(db, lggr, cfg), + ds: ds, lggr: lggr, routerContractAddress: routerContractAddress, }, nil } -func (o *orm) GetAllowedSenders(offset, limit uint, qopts ...pg.QOpt) ([]common.Address, error) { +func (o *orm) GetAllowedSenders(ctx context.Context, offset, limit uint) ([]common.Address, error) { var addresses []common.Address stmt := fmt.Sprintf(` SELECT allowed_address @@ -58,7 +57,7 @@ func (o *orm) GetAllowedSenders(offset, limit uint, qopts ...pg.QOpt) ([]common. OFFSET $2 LIMIT $3; `, tableName) - err := o.q.WithOpts(qopts...).Select(&addresses, stmt, o.routerContractAddress, offset, limit) + err := o.ds.SelectContext(ctx, &addresses, stmt, o.routerContractAddress, offset, limit) if err != nil { return addresses, err } @@ -67,7 +66,7 @@ func (o *orm) GetAllowedSenders(offset, limit uint, qopts ...pg.QOpt) ([]common. return addresses, nil } -func (o *orm) CreateAllowedSenders(allowedSenders []common.Address, qopts ...pg.QOpt) error { +func (o *orm) CreateAllowedSenders(ctx context.Context, allowedSenders []common.Address) error { var valuesPlaceholder []string for i := 1; i <= len(allowedSenders)*2; i += 2 { valuesPlaceholder = append(valuesPlaceholder, fmt.Sprintf("($%d, $%d)", i, i+1)) @@ -82,7 +81,7 @@ func (o *orm) CreateAllowedSenders(allowedSenders []common.Address, qopts ...pg. args = append(args, as, o.routerContractAddress) } - _, err := o.q.WithOpts(qopts...).Exec(stmt, args...) + _, err := o.ds.ExecContext(ctx, stmt, args...) if err != nil { return err } @@ -94,7 +93,7 @@ func (o *orm) CreateAllowedSenders(allowedSenders []common.Address, qopts ...pg. // DeleteAllowedSenders is used to remove blocked senders from the functions_allowlist table. // This is achieved by specifying a list of blockedSenders to remove. -func (o *orm) DeleteAllowedSenders(blockedSenders []common.Address, qopts ...pg.QOpt) error { +func (o *orm) DeleteAllowedSenders(ctx context.Context, blockedSenders []common.Address) error { var valuesPlaceholder []string for i := 1; i <= len(blockedSenders); i++ { valuesPlaceholder = append(valuesPlaceholder, fmt.Sprintf("$%d", i+1)) @@ -110,7 +109,7 @@ func (o *orm) DeleteAllowedSenders(blockedSenders []common.Address, qopts ...pg. args = append(args, bs) } - res, err := o.q.WithOpts(qopts...).Exec(stmt, args...) + res, err := o.ds.ExecContext(ctx, stmt, args...) if err != nil { return err } @@ -126,12 +125,12 @@ func (o *orm) DeleteAllowedSenders(blockedSenders []common.Address, qopts ...pg. } // PurgeAllowedSenders will remove all the allowed senders for the configured orm routerContractAddress -func (o *orm) PurgeAllowedSenders(qopts ...pg.QOpt) error { +func (o *orm) PurgeAllowedSenders(ctx context.Context) error { stmt := fmt.Sprintf(` DELETE FROM %s WHERE router_contract_address = $1;`, tableName) - res, err := o.q.WithOpts(qopts...).Exec(stmt, o.routerContractAddress) + res, err := o.ds.ExecContext(ctx, stmt, o.routerContractAddress) if err != nil { return err } diff --git a/core/services/gateway/handlers/functions/allowlist/orm_test.go b/core/services/gateway/handlers/functions/allowlist/orm_test.go index 1d357616fab..2584e131968 100644 --- a/core/services/gateway/handlers/functions/allowlist/orm_test.go +++ b/core/services/gateway/handlers/functions/allowlist/orm_test.go @@ -20,17 +20,18 @@ func setupORM(t *testing.T) (allowlist.ORM, error) { lggr = logger.TestLogger(t) ) - return allowlist.NewORM(db, lggr, pgtest.NewQConfig(true), testutils.NewAddress()) + return allowlist.NewORM(db, lggr, testutils.NewAddress()) } func seedAllowedSenders(t *testing.T, orm allowlist.ORM, amount int) []common.Address { + ctx := testutils.Context(t) storedAllowedSenders := make([]common.Address, amount) for i := 0; i < amount; i++ { address := testutils.NewAddress() storedAllowedSenders[i] = address } - err := orm.CreateAllowedSenders(storedAllowedSenders) + err := orm.CreateAllowedSenders(ctx, storedAllowedSenders) require.NoError(t, err) return storedAllowedSenders @@ -38,20 +39,22 @@ func seedAllowedSenders(t *testing.T, orm allowlist.ORM, amount int) []common.Ad func TestORM_GetAllowedSenders(t *testing.T) { t.Parallel() t.Run("fetch first page", func(t *testing.T) { + ctx := testutils.Context(t) orm, err := setupORM(t) require.NoError(t, err) storedAllowedSenders := seedAllowedSenders(t, orm, 2) - results, err := orm.GetAllowedSenders(0, 1) + results, err := orm.GetAllowedSenders(ctx, 0, 1) require.NoError(t, err) require.Equal(t, 1, len(results), "incorrect results length") require.Equal(t, storedAllowedSenders[0], results[0]) }) t.Run("fetch second page", func(t *testing.T) { + ctx := testutils.Context(t) orm, err := setupORM(t) require.NoError(t, err) storedAllowedSenders := seedAllowedSenders(t, orm, 2) - results, err := orm.GetAllowedSenders(1, 5) + results, err := orm.GetAllowedSenders(ctx, 1, 5) require.NoError(t, err) require.Equal(t, 1, len(results), "incorrect results length") require.Equal(t, storedAllowedSenders[1], results[0]) @@ -62,42 +65,45 @@ func TestORM_CreateAllowedSenders(t *testing.T) { t.Parallel() t.Run("OK-create_an_allowed_sender", func(t *testing.T) { + ctx := testutils.Context(t) orm, err := setupORM(t) require.NoError(t, err) expected := testutils.NewAddress() - err = orm.CreateAllowedSenders([]common.Address{expected}) + err = orm.CreateAllowedSenders(ctx, []common.Address{expected}) require.NoError(t, err) - results, err := orm.GetAllowedSenders(0, 1) + results, err := orm.GetAllowedSenders(ctx, 0, 1) require.NoError(t, err) require.Equal(t, 1, len(results), "incorrect results length") require.Equal(t, expected, results[0]) }) t.Run("OK-create_an_existing_allowed_sender", func(t *testing.T) { + ctx := testutils.Context(t) orm, err := setupORM(t) require.NoError(t, err) expected := testutils.NewAddress() - err = orm.CreateAllowedSenders([]common.Address{expected}) + err = orm.CreateAllowedSenders(ctx, []common.Address{expected}) require.NoError(t, err) - err = orm.CreateAllowedSenders([]common.Address{expected}) + err = orm.CreateAllowedSenders(ctx, []common.Address{expected}) require.NoError(t, err) - results, err := orm.GetAllowedSenders(0, 5) + results, err := orm.GetAllowedSenders(ctx, 0, 5) require.NoError(t, err) require.Equal(t, 1, len(results), "incorrect results length") require.Equal(t, expected, results[0]) }) t.Run("OK-create_multiple_allowed_senders_in_one_query", func(t *testing.T) { + ctx := testutils.Context(t) orm, err := setupORM(t) require.NoError(t, err) expected := []common.Address{testutils.NewAddress(), testutils.NewAddress()} - err = orm.CreateAllowedSenders(expected) + err = orm.CreateAllowedSenders(ctx, expected) require.NoError(t, err) - results, err := orm.GetAllowedSenders(0, 2) + results, err := orm.GetAllowedSenders(ctx, 0, 2) require.NoError(t, err) require.Equal(t, 2, len(results), "incorrect results length") require.Equal(t, expected[0], results[0]) @@ -105,6 +111,7 @@ func TestORM_CreateAllowedSenders(t *testing.T) { }) t.Run("OK-create_multiple_allowed_senders_with_duplicates", func(t *testing.T) { + ctx := testutils.Context(t) orm, err := setupORM(t) require.NoError(t, err) addr1 := testutils.NewAddress() @@ -112,10 +119,10 @@ func TestORM_CreateAllowedSenders(t *testing.T) { expected := []common.Address{addr1, addr2} duplicatedAddressInput := []common.Address{addr1, addr1, addr1, addr2} - err = orm.CreateAllowedSenders(duplicatedAddressInput) + err = orm.CreateAllowedSenders(ctx, duplicatedAddressInput) require.NoError(t, err) - results, err := orm.GetAllowedSenders(0, 10) + results, err := orm.GetAllowedSenders(ctx, 0, 10) require.NoError(t, err) require.Equal(t, 2, len(results), "incorrect results length") require.Equal(t, expected[0], results[0]) @@ -127,46 +134,48 @@ func TestORM_DeleteAllowedSenders(t *testing.T) { t.Parallel() t.Run("OK-delete_blocked_sender_from_allowed_list", func(t *testing.T) { + ctx := testutils.Context(t) orm, err := setupORM(t) require.NoError(t, err) add1 := testutils.NewAddress() add2 := testutils.NewAddress() add3 := testutils.NewAddress() - err = orm.CreateAllowedSenders([]common.Address{add1, add2, add3}) + err = orm.CreateAllowedSenders(ctx, []common.Address{add1, add2, add3}) require.NoError(t, err) - results, err := orm.GetAllowedSenders(0, 10) + results, err := orm.GetAllowedSenders(ctx, 0, 10) require.NoError(t, err) require.Equal(t, 3, len(results), "incorrect results length") require.Equal(t, add1, results[0]) - err = orm.DeleteAllowedSenders([]common.Address{add1, add3}) + err = orm.DeleteAllowedSenders(ctx, []common.Address{add1, add3}) require.NoError(t, err) - results, err = orm.GetAllowedSenders(0, 10) + results, err = orm.GetAllowedSenders(ctx, 0, 10) require.NoError(t, err) require.Equal(t, 1, len(results), "incorrect results length") require.Equal(t, add2, results[0]) }) t.Run("OK-delete_non_existing_blocked_sender_from_allowed_list", func(t *testing.T) { + ctx := testutils.Context(t) orm, err := setupORM(t) require.NoError(t, err) add1 := testutils.NewAddress() add2 := testutils.NewAddress() - err = orm.CreateAllowedSenders([]common.Address{add1, add2}) + err = orm.CreateAllowedSenders(ctx, []common.Address{add1, add2}) require.NoError(t, err) - results, err := orm.GetAllowedSenders(0, 10) + results, err := orm.GetAllowedSenders(ctx, 0, 10) require.NoError(t, err) require.Equal(t, 2, len(results), "incorrect results length") require.Equal(t, add1, results[0]) add3 := testutils.NewAddress() - err = orm.DeleteAllowedSenders([]common.Address{add3}) + err = orm.DeleteAllowedSenders(ctx, []common.Address{add3}) require.NoError(t, err) - results, err = orm.GetAllowedSenders(0, 10) + results, err = orm.GetAllowedSenders(ctx, 0, 10) require.NoError(t, err) require.Equal(t, 2, len(results), "incorrect results length") require.Equal(t, add1, results[0]) @@ -178,36 +187,38 @@ func TestORM_PurgeAllowedSenders(t *testing.T) { t.Parallel() t.Run("OK-purge_allowed_list", func(t *testing.T) { + ctx := testutils.Context(t) orm, err := setupORM(t) require.NoError(t, err) add1 := testutils.NewAddress() add2 := testutils.NewAddress() add3 := testutils.NewAddress() - err = orm.CreateAllowedSenders([]common.Address{add1, add2, add3}) + err = orm.CreateAllowedSenders(ctx, []common.Address{add1, add2, add3}) require.NoError(t, err) - results, err := orm.GetAllowedSenders(0, 10) + results, err := orm.GetAllowedSenders(ctx, 0, 10) require.NoError(t, err) require.Equal(t, 3, len(results), "incorrect results length") require.Equal(t, add1, results[0]) - err = orm.PurgeAllowedSenders() + err = orm.PurgeAllowedSenders(ctx) require.NoError(t, err) - results, err = orm.GetAllowedSenders(0, 10) + results, err = orm.GetAllowedSenders(ctx, 0, 10) require.NoError(t, err) require.Equal(t, 0, len(results), "incorrect results length") }) t.Run("OK-purge_allowed_list_for_contract_address", func(t *testing.T) { + ctx := testutils.Context(t) orm1, err := setupORM(t) require.NoError(t, err) add1 := testutils.NewAddress() add2 := testutils.NewAddress() - err = orm1.CreateAllowedSenders([]common.Address{add1, add2}) + err = orm1.CreateAllowedSenders(ctx, []common.Address{add1, add2}) require.NoError(t, err) - results, err := orm1.GetAllowedSenders(0, 10) + results, err := orm1.GetAllowedSenders(ctx, 0, 10) require.NoError(t, err) require.Equal(t, 2, len(results), "incorrect results length") require.Equal(t, add1, results[0]) @@ -216,22 +227,22 @@ func TestORM_PurgeAllowedSenders(t *testing.T) { require.NoError(t, err) add3 := testutils.NewAddress() add4 := testutils.NewAddress() - err = orm2.CreateAllowedSenders([]common.Address{add3, add4}) + err = orm2.CreateAllowedSenders(ctx, []common.Address{add3, add4}) require.NoError(t, err) - results, err = orm2.GetAllowedSenders(0, 10) + results, err = orm2.GetAllowedSenders(ctx, 0, 10) require.NoError(t, err) require.Equal(t, 2, len(results), "incorrect results length") require.Equal(t, add3, results[0]) - err = orm2.PurgeAllowedSenders() + err = orm2.PurgeAllowedSenders(ctx) require.NoError(t, err) - results, err = orm2.GetAllowedSenders(0, 10) + results, err = orm2.GetAllowedSenders(ctx, 0, 10) require.NoError(t, err) require.Equal(t, 0, len(results), "incorrect results length") - results, err = orm1.GetAllowedSenders(0, 10) + results, err = orm1.GetAllowedSenders(ctx, 0, 10) require.NoError(t, err) require.Equal(t, 2, len(results), "incorrect results length") require.Equal(t, add1, results[0]) @@ -241,15 +252,15 @@ func TestORM_PurgeAllowedSenders(t *testing.T) { func Test_NewORM(t *testing.T) { t.Run("OK-create_ORM", func(t *testing.T) { - _, err := allowlist.NewORM(pgtest.NewSqlxDB(t), logger.TestLogger(t), pgtest.NewQConfig(true), testutils.NewAddress()) + _, err := allowlist.NewORM(pgtest.NewSqlxDB(t), logger.TestLogger(t), testutils.NewAddress()) require.NoError(t, err) }) t.Run("NOK-create_ORM_with_nil_fields", func(t *testing.T) { - _, err := allowlist.NewORM(nil, nil, nil, common.Address{}) + _, err := allowlist.NewORM(nil, nil, common.Address{}) require.Error(t, err) }) t.Run("NOK-create_ORM_with_empty_address", func(t *testing.T) { - _, err := allowlist.NewORM(pgtest.NewSqlxDB(t), logger.TestLogger(t), pgtest.NewQConfig(true), common.Address{}) + _, err := allowlist.NewORM(pgtest.NewSqlxDB(t), logger.TestLogger(t), common.Address{}) require.Error(t, err) }) } diff --git a/core/services/gateway/handlers/functions/handler.functions.go b/core/services/gateway/handlers/functions/handler.functions.go index ff272e4e577..692534db598 100644 --- a/core/services/gateway/handlers/functions/handler.functions.go +++ b/core/services/gateway/handlers/functions/handler.functions.go @@ -114,7 +114,7 @@ func NewFunctionsHandlerFromConfig(handlerConfig json.RawMessage, donConfig *con return nil, err2 } - orm, err2 := fallow.NewORM(db, lggr, qcfg, cfg.OnchainAllowlist.ContractAddress) + orm, err2 := fallow.NewORM(db, lggr, cfg.OnchainAllowlist.ContractAddress) if err2 != nil { return nil, err2 } @@ -143,7 +143,7 @@ func NewFunctionsHandlerFromConfig(handlerConfig json.RawMessage, donConfig *con return nil, err2 } - orm, err2 := fsub.NewORM(db, lggr, qcfg, cfg.OnchainSubscriptions.ContractAddress) + orm, err2 := fsub.NewORM(db, lggr, cfg.OnchainSubscriptions.ContractAddress) if err2 != nil { return nil, err2 } diff --git a/core/services/gateway/handlers/functions/subscriptions/mocks/orm.go b/core/services/gateway/handlers/functions/subscriptions/mocks/orm.go index 0f278aa49b0..16a82a488b4 100644 --- a/core/services/gateway/handlers/functions/subscriptions/mocks/orm.go +++ b/core/services/gateway/handlers/functions/subscriptions/mocks/orm.go @@ -3,8 +3,9 @@ package mocks import ( + context "context" + subscriptions "github.com/smartcontractkit/chainlink/v2/core/services/gateway/handlers/functions/subscriptions" - pg "github.com/smartcontractkit/chainlink/v2/core/services/pg" mock "github.com/stretchr/testify/mock" ) @@ -13,16 +14,9 @@ type ORM struct { mock.Mock } -// GetSubscriptions provides a mock function with given fields: offset, limit, qopts -func (_m *ORM) GetSubscriptions(offset uint, limit uint, qopts ...pg.QOpt) ([]subscriptions.StoredSubscription, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, offset, limit) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// GetSubscriptions provides a mock function with given fields: ctx, offset, limit +func (_m *ORM) GetSubscriptions(ctx context.Context, offset uint, limit uint) ([]subscriptions.StoredSubscription, error) { + ret := _m.Called(ctx, offset, limit) if len(ret) == 0 { panic("no return value specified for GetSubscriptions") @@ -30,19 +24,19 @@ func (_m *ORM) GetSubscriptions(offset uint, limit uint, qopts ...pg.QOpt) ([]su var r0 []subscriptions.StoredSubscription var r1 error - if rf, ok := ret.Get(0).(func(uint, uint, ...pg.QOpt) ([]subscriptions.StoredSubscription, error)); ok { - return rf(offset, limit, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, uint, uint) ([]subscriptions.StoredSubscription, error)); ok { + return rf(ctx, offset, limit) } - if rf, ok := ret.Get(0).(func(uint, uint, ...pg.QOpt) []subscriptions.StoredSubscription); ok { - r0 = rf(offset, limit, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, uint, uint) []subscriptions.StoredSubscription); ok { + r0 = rf(ctx, offset, limit) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]subscriptions.StoredSubscription) } } - if rf, ok := ret.Get(1).(func(uint, uint, ...pg.QOpt) error); ok { - r1 = rf(offset, limit, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, uint, uint) error); ok { + r1 = rf(ctx, offset, limit) } else { r1 = ret.Error(1) } @@ -50,24 +44,17 @@ func (_m *ORM) GetSubscriptions(offset uint, limit uint, qopts ...pg.QOpt) ([]su return r0, r1 } -// UpsertSubscription provides a mock function with given fields: subscription, qopts -func (_m *ORM) UpsertSubscription(subscription subscriptions.StoredSubscription, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, subscription) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// UpsertSubscription provides a mock function with given fields: ctx, subscription +func (_m *ORM) UpsertSubscription(ctx context.Context, subscription subscriptions.StoredSubscription) error { + ret := _m.Called(ctx, subscription) if len(ret) == 0 { panic("no return value specified for UpsertSubscription") } var r0 error - if rf, ok := ret.Get(0).(func(subscriptions.StoredSubscription, ...pg.QOpt) error); ok { - r0 = rf(subscription, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, subscriptions.StoredSubscription) error); ok { + r0 = rf(ctx, subscription) } else { r0 = ret.Error(0) } diff --git a/core/services/gateway/handlers/functions/subscriptions/orm.go b/core/services/gateway/handlers/functions/subscriptions/orm.go index 369291ace54..d97437a39dc 100644 --- a/core/services/gateway/handlers/functions/subscriptions/orm.go +++ b/core/services/gateway/handlers/functions/subscriptions/orm.go @@ -1,6 +1,7 @@ package subscriptions import ( + "context" "fmt" "math/big" @@ -8,21 +9,19 @@ import ( "github.com/lib/pq" "github.com/pkg/errors" - "github.com/jmoiron/sqlx" - + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/functions/generated/functions_router" "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) //go:generate mockery --quiet --name ORM --output ./mocks/ --case=underscore type ORM interface { - GetSubscriptions(offset, limit uint, qopts ...pg.QOpt) ([]StoredSubscription, error) - UpsertSubscription(subscription StoredSubscription, qopts ...pg.QOpt) error + GetSubscriptions(ctx context.Context, offset, limit uint) ([]StoredSubscription, error) + UpsertSubscription(ctx context.Context, subscription StoredSubscription) error } type orm struct { - q pg.Q + ds sqlutil.DataSource lggr logger.Logger routerContractAddress common.Address } @@ -47,19 +46,19 @@ type storedSubscriptionRow struct { RouterContractAddress common.Address } -func NewORM(db *sqlx.DB, lggr logger.Logger, cfg pg.QConfig, routerContractAddress common.Address) (ORM, error) { - if db == nil || cfg == nil || lggr == nil || routerContractAddress == (common.Address{}) { +func NewORM(ds sqlutil.DataSource, lggr logger.Logger, routerContractAddress common.Address) (ORM, error) { + if ds == nil || lggr == nil || routerContractAddress == (common.Address{}) { return nil, ErrInvalidParameters } return &orm{ - q: pg.NewQ(db, lggr, cfg), + ds: ds, lggr: lggr, routerContractAddress: routerContractAddress, }, nil } -func (o *orm) GetSubscriptions(offset, limit uint, qopts ...pg.QOpt) ([]StoredSubscription, error) { +func (o *orm) GetSubscriptions(ctx context.Context, offset, limit uint) ([]StoredSubscription, error) { var storedSubscriptions []StoredSubscription var storedSubscriptionRows []storedSubscriptionRow stmt := fmt.Sprintf(` @@ -70,7 +69,7 @@ func (o *orm) GetSubscriptions(offset, limit uint, qopts ...pg.QOpt) ([]StoredSu OFFSET $2 LIMIT $3; `, tableName) - err := o.q.WithOpts(qopts...).Select(&storedSubscriptionRows, stmt, o.routerContractAddress, offset, limit) + err := o.ds.SelectContext(ctx, &storedSubscriptionRows, stmt, o.routerContractAddress, offset, limit) if err != nil { return storedSubscriptions, err } @@ -84,7 +83,7 @@ func (o *orm) GetSubscriptions(offset, limit uint, qopts ...pg.QOpt) ([]StoredSu // UpsertSubscription will update if a subscription exists or create if it does not. // In case a subscription gets deleted we will update it with an owner address equal to 0x0. -func (o *orm) UpsertSubscription(subscription StoredSubscription, qopts ...pg.QOpt) error { +func (o *orm) UpsertSubscription(ctx context.Context, subscription StoredSubscription) error { stmt := fmt.Sprintf(` INSERT INTO %s (subscription_id, owner, balance, blocked_balance, proposed_owner, consumers, flags, router_contract_address) VALUES ($1,$2,$3,$4,$5,$6,$7,$8) ON CONFLICT (subscription_id, router_contract_address) DO UPDATE @@ -103,7 +102,8 @@ func (o *orm) UpsertSubscription(subscription StoredSubscription, qopts ...pg.QO consumers = append(consumers, c.Bytes()) } - _, err := o.q.WithOpts(qopts...).Exec( + _, err := o.ds.ExecContext( + ctx, stmt, subscription.SubscriptionID, subscription.Owner, diff --git a/core/services/gateway/handlers/functions/subscriptions/orm_test.go b/core/services/gateway/handlers/functions/subscriptions/orm_test.go index 6cb1146f03c..f75ab0b98c1 100644 --- a/core/services/gateway/handlers/functions/subscriptions/orm_test.go +++ b/core/services/gateway/handlers/functions/subscriptions/orm_test.go @@ -27,10 +27,11 @@ func setupORM(t *testing.T) (subscriptions.ORM, error) { lggr = logger.TestLogger(t) ) - return subscriptions.NewORM(db, lggr, pgtest.NewQConfig(true), testutils.NewAddress()) + return subscriptions.NewORM(db, lggr, testutils.NewAddress()) } func seedSubscriptions(t *testing.T, orm subscriptions.ORM, amount int) []subscriptions.StoredSubscription { + ctx := testutils.Context(t) storedSubscriptions := make([]subscriptions.StoredSubscription, 0) for i := amount; i > 0; i-- { cs := subscriptions.StoredSubscription{ @@ -45,7 +46,7 @@ func seedSubscriptions(t *testing.T, orm subscriptions.ORM, amount int) []subscr }, } storedSubscriptions = append(storedSubscriptions, cs) - err := orm.UpsertSubscription(cs) + err := orm.UpsertSubscription(ctx, cs) require.NoError(t, err) } return storedSubscriptions @@ -54,20 +55,22 @@ func seedSubscriptions(t *testing.T, orm subscriptions.ORM, amount int) []subscr func TestORM_GetSubscriptions(t *testing.T) { t.Parallel() t.Run("fetch first page", func(t *testing.T) { + ctx := testutils.Context(t) orm, err := setupORM(t) require.NoError(t, err) storedSubscriptions := seedSubscriptions(t, orm, 2) - results, err := orm.GetSubscriptions(0, 1) + results, err := orm.GetSubscriptions(ctx, 0, 1) require.NoError(t, err) require.Equal(t, 1, len(results), "incorrect results length") require.Equal(t, storedSubscriptions[1], results[0]) }) t.Run("fetch second page", func(t *testing.T) { + ctx := testutils.Context(t) orm, err := setupORM(t) require.NoError(t, err) storedSubscriptions := seedSubscriptions(t, orm, 2) - results, err := orm.GetSubscriptions(1, 5) + results, err := orm.GetSubscriptions(ctx, 1, 5) require.NoError(t, err) require.Equal(t, 1, len(results), "incorrect results length") require.Equal(t, storedSubscriptions[0], results[0]) @@ -78,6 +81,7 @@ func TestORM_UpsertSubscription(t *testing.T) { t.Parallel() t.Run("create a subscription", func(t *testing.T) { + ctx := testutils.Context(t) orm, err := setupORM(t) require.NoError(t, err) expected := subscriptions.StoredSubscription{ @@ -91,16 +95,17 @@ func TestORM_UpsertSubscription(t *testing.T) { Flags: defaultFlags, }, } - err = orm.UpsertSubscription(expected) + err = orm.UpsertSubscription(ctx, expected) require.NoError(t, err) - results, err := orm.GetSubscriptions(0, 1) + results, err := orm.GetSubscriptions(ctx, 0, 1) require.NoError(t, err) require.Equal(t, 1, len(results), "incorrect results length") require.Equal(t, expected, results[0]) }) t.Run("update a subscription", func(t *testing.T) { + ctx := testutils.Context(t) orm, err := setupORM(t) require.NoError(t, err) @@ -115,7 +120,7 @@ func TestORM_UpsertSubscription(t *testing.T) { Flags: defaultFlags, }, } - err = orm.UpsertSubscription(expectedUpdated) + err = orm.UpsertSubscription(ctx, expectedUpdated) require.NoError(t, err) expectedNotUpdated := subscriptions.StoredSubscription{ @@ -129,15 +134,15 @@ func TestORM_UpsertSubscription(t *testing.T) { Flags: defaultFlags, }, } - err = orm.UpsertSubscription(expectedNotUpdated) + err = orm.UpsertSubscription(ctx, expectedNotUpdated) require.NoError(t, err) // update the balance value expectedUpdated.Balance = big.NewInt(20) - err = orm.UpsertSubscription(expectedUpdated) + err = orm.UpsertSubscription(ctx, expectedUpdated) require.NoError(t, err) - results, err := orm.GetSubscriptions(0, 5) + results, err := orm.GetSubscriptions(ctx, 0, 5) require.NoError(t, err) require.Equal(t, 2, len(results), "incorrect results length") require.Equal(t, expectedNotUpdated, results[1]) @@ -145,6 +150,7 @@ func TestORM_UpsertSubscription(t *testing.T) { }) t.Run("update a deleted subscription", func(t *testing.T) { + ctx := testutils.Context(t) orm, err := setupORM(t) require.NoError(t, err) @@ -159,7 +165,7 @@ func TestORM_UpsertSubscription(t *testing.T) { Flags: defaultFlags, }, } - err = orm.UpsertSubscription(subscription) + err = orm.UpsertSubscription(ctx, subscription) require.NoError(t, err) // empty subscription @@ -172,24 +178,25 @@ func TestORM_UpsertSubscription(t *testing.T) { Flags: [32]byte{}, } - err = orm.UpsertSubscription(subscription) + err = orm.UpsertSubscription(ctx, subscription) require.NoError(t, err) - results, err := orm.GetSubscriptions(0, 5) + results, err := orm.GetSubscriptions(ctx, 0, 5) require.NoError(t, err) require.Equal(t, 1, len(results), "incorrect results length") require.Equal(t, subscription, results[0]) }) t.Run("create a subscription with same id but different router address", func(t *testing.T) { + ctx := testutils.Context(t) var ( db = pgtest.NewSqlxDB(t) lggr = logger.TestLogger(t) ) - orm1, err := subscriptions.NewORM(db, lggr, pgtest.NewQConfig(true), testutils.NewAddress()) + orm1, err := subscriptions.NewORM(db, lggr, testutils.NewAddress()) require.NoError(t, err) - orm2, err := subscriptions.NewORM(db, lggr, pgtest.NewQConfig(true), testutils.NewAddress()) + orm2, err := subscriptions.NewORM(db, lggr, testutils.NewAddress()) require.NoError(t, err) subscription := subscriptions.StoredSubscription{ @@ -204,42 +211,42 @@ func TestORM_UpsertSubscription(t *testing.T) { }, } - err = orm1.UpsertSubscription(subscription) + err = orm1.UpsertSubscription(ctx, subscription) require.NoError(t, err) // should update the existing subscription subscription.Balance = assets.Ether(12).ToInt() - err = orm1.UpsertSubscription(subscription) + err = orm1.UpsertSubscription(ctx, subscription) require.NoError(t, err) - results, err := orm1.GetSubscriptions(0, 10) + results, err := orm1.GetSubscriptions(ctx, 0, 10) require.NoError(t, err) require.Equal(t, 1, len(results), "incorrect results length") // should create a new subscription because it comes from a different router contract - err = orm2.UpsertSubscription(subscription) + err = orm2.UpsertSubscription(ctx, subscription) require.NoError(t, err) - results, err = orm1.GetSubscriptions(0, 10) + results, err = orm1.GetSubscriptions(ctx, 0, 10) require.NoError(t, err) require.Equal(t, 1, len(results), "incorrect results length") - results, err = orm2.GetSubscriptions(0, 10) + results, err = orm2.GetSubscriptions(ctx, 0, 10) require.NoError(t, err) require.Equal(t, 1, len(results), "incorrect results length") }) } func Test_NewORM(t *testing.T) { t.Run("OK-create_ORM", func(t *testing.T) { - _, err := subscriptions.NewORM(pgtest.NewSqlxDB(t), logger.TestLogger(t), pgtest.NewQConfig(true), testutils.NewAddress()) + _, err := subscriptions.NewORM(pgtest.NewSqlxDB(t), logger.TestLogger(t), testutils.NewAddress()) require.NoError(t, err) }) t.Run("NOK-create_ORM_with_nil_fields", func(t *testing.T) { - _, err := subscriptions.NewORM(nil, nil, nil, common.Address{}) + _, err := subscriptions.NewORM(nil, nil, common.Address{}) require.Error(t, err) }) t.Run("NOK-create_ORM_with_empty_address", func(t *testing.T) { - _, err := subscriptions.NewORM(pgtest.NewSqlxDB(t), logger.TestLogger(t), pgtest.NewQConfig(true), common.Address{}) + _, err := subscriptions.NewORM(pgtest.NewSqlxDB(t), logger.TestLogger(t), common.Address{}) require.Error(t, err) }) } diff --git a/core/services/gateway/handlers/functions/subscriptions/subscriptions.go b/core/services/gateway/handlers/functions/subscriptions/subscriptions.go index e90201a31a9..d481ecf12ed 100644 --- a/core/services/gateway/handlers/functions/subscriptions/subscriptions.go +++ b/core/services/gateway/handlers/functions/subscriptions/subscriptions.go @@ -99,7 +99,7 @@ func (s *onchainSubscriptions) Start(ctx context.Context) error { return errors.New("OnchainSubscriptionsConfig.UpdateRangeSize must be greater than 0") } - s.loadStoredSubscriptions() + s.loadStoredSubscriptions(ctx) s.closeWait.Add(1) go s.queryLoop() @@ -206,7 +206,7 @@ func (s *onchainSubscriptions) querySubscriptionsRange(ctx context.Context, bloc subscription := subscription updated := s.subscriptions.UpdateSubscription(subscriptionId, &subscription) if updated { - if err = s.orm.UpsertSubscription(StoredSubscription{ + if err = s.orm.UpsertSubscription(ctx, StoredSubscription{ SubscriptionID: subscriptionId, IFunctionsSubscriptionsSubscription: subscription, }); err != nil { @@ -226,10 +226,10 @@ func (s *onchainSubscriptions) getSubscriptionsCount(ctx context.Context, blockN }) } -func (s *onchainSubscriptions) loadStoredSubscriptions() { +func (s *onchainSubscriptions) loadStoredSubscriptions(ctx context.Context) { offset := uint(0) for { - csBatch, err := s.orm.GetSubscriptions(offset, s.config.StoreBatchSize) + csBatch, err := s.orm.GetSubscriptions(ctx, offset, s.config.StoreBatchSize) if err != nil { break } diff --git a/core/services/gateway/handlers/functions/subscriptions/subscriptions_test.go b/core/services/gateway/handlers/functions/subscriptions/subscriptions_test.go index be1d2520434..212029b73f7 100644 --- a/core/services/gateway/handlers/functions/subscriptions/subscriptions_test.go +++ b/core/services/gateway/handlers/functions/subscriptions/subscriptions_test.go @@ -51,8 +51,8 @@ func TestSubscriptions_OnePass(t *testing.T) { UpdateRangeSize: 3, } orm := smocks.NewORM(t) - orm.On("GetSubscriptions", uint(0), uint(100)).Return([]subscriptions.StoredSubscription{}, nil) - orm.On("UpsertSubscription", mock.Anything).Return(nil) + orm.On("GetSubscriptions", mock.Anything, uint(0), uint(100)).Return([]subscriptions.StoredSubscription{}, nil) + orm.On("UpsertSubscription", mock.Anything, mock.Anything).Return(nil) subscriptions, err := subscriptions.NewOnchainSubscriptions(client, config, orm, logger.TestLogger(t)) require.NoError(t, err) @@ -102,8 +102,8 @@ func TestSubscriptions_MultiPass(t *testing.T) { UpdateRangeSize: 3, } orm := smocks.NewORM(t) - orm.On("GetSubscriptions", uint(0), uint(100)).Return([]subscriptions.StoredSubscription{}, nil) - orm.On("UpsertSubscription", mock.Anything).Return(nil) + orm.On("GetSubscriptions", mock.Anything, uint(0), uint(100)).Return([]subscriptions.StoredSubscription{}, nil) + orm.On("UpsertSubscription", mock.Anything, mock.Anything).Return(nil) subscriptions, err := subscriptions.NewOnchainSubscriptions(client, config, orm, logger.TestLogger(t)) require.NoError(t, err) @@ -144,7 +144,7 @@ func TestSubscriptions_Stored(t *testing.T) { expectedBalance := big.NewInt(5) orm := smocks.NewORM(t) - orm.On("GetSubscriptions", uint(0), uint(1)).Return([]subscriptions.StoredSubscription{ + orm.On("GetSubscriptions", mock.Anything, uint(0), uint(1)).Return([]subscriptions.StoredSubscription{ { SubscriptionID: 1, IFunctionsSubscriptionsSubscription: functions_router.IFunctionsSubscriptionsSubscription{ @@ -154,8 +154,8 @@ func TestSubscriptions_Stored(t *testing.T) { }, }, }, nil) - orm.On("GetSubscriptions", uint(1), uint(1)).Return([]subscriptions.StoredSubscription{}, nil) - orm.On("UpsertSubscription", mock.Anything).Return(nil) + orm.On("GetSubscriptions", mock.Anything, uint(1), uint(1)).Return([]subscriptions.StoredSubscription{}, nil) + orm.On("UpsertSubscription", mock.Anything, mock.Anything).Return(nil) subscriptions, err := subscriptions.NewOnchainSubscriptions(client, config, orm, logger.TestLogger(t)) require.NoError(t, err) diff --git a/core/services/job/spawner_test.go b/core/services/job/spawner_test.go index 3ac32309775..d2e7a80d5d4 100644 --- a/core/services/job/spawner_test.go +++ b/core/services/job/spawner_test.go @@ -306,7 +306,7 @@ func TestSpawner_CreateJobDeleteJob(t *testing.T) { processConfig := plugins.NewRegistrarConfig(loop.GRPCOpts{}, func(name string) (*plugins.RegisteredLoop, error) { return nil, nil }, func(loopId string) {}) ocr2DelegateConfig := ocr2.NewDelegateConfig(config.OCR2(), config.Mercury(), config.Threshold(), config.Insecure(), config.JobPipeline(), config.Database(), processConfig) - d := ocr2.NewDelegate(nil, orm, nil, nil, nil, nil, nil, monitoringEndpoint, legacyChains, lggr, ocr2DelegateConfig, + d := ocr2.NewDelegate(nil, nil, orm, nil, nil, nil, nil, nil, monitoringEndpoint, legacyChains, lggr, ocr2DelegateConfig, keyStore.OCR2(), keyStore.DKGSign(), keyStore.DKGEncrypt(), ethKeyStore, testRelayGetter, mailMon, capabilities.NewRegistry(lggr)) delegateOCR2 := &delegate{jobOCR2VRF.Type, []job.ServiceCtx{}, 0, nil, d} diff --git a/core/services/ocr2/delegate.go b/core/services/ocr2/delegate.go index 7b4200efd68..a00ed195903 100644 --- a/core/services/ocr2/delegate.go +++ b/core/services/ocr2/delegate.go @@ -27,6 +27,7 @@ import ( ocr2keepers21config "github.com/smartcontractkit/chainlink-automation/pkg/v3/config" ocr2keepers21 "github.com/smartcontractkit/chainlink-automation/pkg/v3/plugin" "github.com/smartcontractkit/chainlink-common/pkg/loop/reportingplugins/ocr3" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/config/env" "github.com/smartcontractkit/chainlink-vrf/altbn_128" @@ -109,7 +110,8 @@ type RelayGetter interface { Get(id relay.ID) (loop.Relayer, error) } type Delegate struct { - db *sqlx.DB + db *sqlx.DB // legacy: prefer to use ds instead + ds sqlutil.DataSource jobORM job.ORM bridgeORM bridges.ORM mercuryORM evmmercury.ORM @@ -223,6 +225,7 @@ var _ job.Delegate = (*Delegate)(nil) func NewDelegate( db *sqlx.DB, + ds sqlutil.DataSource, jobORM job.ORM, bridgeORM bridges.ORM, mercuryORM evmmercury.ORM, @@ -243,6 +246,7 @@ func NewDelegate( ) *Delegate { return &Delegate{ db: db, + ds: ds, jobORM: jobORM, bridgeORM: bridgeORM, mercuryORM: mercuryORM, @@ -1669,8 +1673,7 @@ func (d *Delegate) newServicesOCR2Functions( Job: jb, JobORM: d.jobORM, BridgeORM: d.bridgeORM, - QConfig: d.cfg.Database(), - DB: d.db, + DS: d.ds, Chain: chain, ContractID: spec.ContractID, Logger: lggr, diff --git a/core/services/ocr2/plugins/functions/plugin.go b/core/services/ocr2/plugins/functions/plugin.go index 92b15892885..d6ffa1a3f06 100644 --- a/core/services/ocr2/plugins/functions/plugin.go +++ b/core/services/ocr2/plugins/functions/plugin.go @@ -8,13 +8,13 @@ import ( "time" "github.com/ethereum/go-ethereum/common" - "github.com/jmoiron/sqlx" "github.com/jonboulle/clockwork" "github.com/pkg/errors" "github.com/smartcontractkit/libocr/commontypes" libocr2 "github.com/smartcontractkit/libocr/offchainreporting2plus" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/utils/mailbox" "github.com/smartcontractkit/chainlink/v2/core/bridges" @@ -31,7 +31,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/functions/config" s4_plugin "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/s4" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/threshold" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" evmrelayTypes "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm/types" "github.com/smartcontractkit/chainlink/v2/core/services/s4" ) @@ -40,8 +39,7 @@ type FunctionsServicesConfig struct { Job job.Job JobORM job.ORM BridgeORM bridges.ORM - QConfig pg.QConfig - DB *sqlx.DB + DS sqlutil.DataSource Chain legacyevm.Chain ContractID string Logger logger.Logger @@ -63,8 +61,8 @@ const ( // Create all OCR2 plugin Oracles and all extra services needed to run a Functions job. func NewFunctionsServices(ctx context.Context, functionsOracleArgs, thresholdOracleArgs, s4OracleArgs *libocr2.OCR2OracleArgs, conf *FunctionsServicesConfig) ([]job.ServiceCtx, error) { - pluginORM := functions.NewORM(conf.DB, conf.Logger, conf.QConfig, common.HexToAddress(conf.ContractID)) - s4ORM := s4.NewCachedORMWrapper(s4.NewPostgresORM(conf.DB, conf.Logger, conf.QConfig, s4.SharedTableName, FunctionsS4Namespace), conf.Logger) + pluginORM := functions.NewORM(conf.DS, common.HexToAddress(conf.ContractID)) + s4ORM := s4.NewCachedORMWrapper(s4.NewPostgresORM(conf.DS, s4.SharedTableName, FunctionsS4Namespace), conf.Logger) var pluginConfig config.PluginConfig if err := json.Unmarshal(conf.Job.OCR2OracleSpec.PluginConfig.Bytes(), &pluginConfig); err != nil { @@ -155,7 +153,7 @@ func NewFunctionsServices(ctx context.Context, functionsOracleArgs, thresholdOra allServices = append(allServices, job.NewServiceAdapter(functionsReportingPluginOracle)) if pluginConfig.GatewayConnectorConfig != nil && s4Storage != nil && pluginConfig.OnchainAllowlist != nil && pluginConfig.RateLimiter != nil && pluginConfig.OnchainSubscriptions != nil { - allowlistORM, err := gwAllowlist.NewORM(conf.DB, conf.Logger, conf.QConfig, pluginConfig.OnchainAllowlist.ContractAddress) + allowlistORM, err := gwAllowlist.NewORM(conf.DS, conf.Logger, pluginConfig.OnchainAllowlist.ContractAddress) if err != nil { return nil, errors.Wrap(err, "failed to create allowlist ORM") } @@ -167,7 +165,7 @@ func NewFunctionsServices(ctx context.Context, functionsOracleArgs, thresholdOra if err2 != nil { return nil, errors.Wrap(err, "failed to create a RateLimiter") } - subscriptionsORM, err := gwSubscriptions.NewORM(conf.DB, conf.Logger, conf.QConfig, pluginConfig.OnchainSubscriptions.ContractAddress) + subscriptionsORM, err := gwSubscriptions.NewORM(conf.DS, conf.Logger, pluginConfig.OnchainSubscriptions.ContractAddress) if err != nil { return nil, errors.Wrap(err, "failed to create subscriptions ORM") } diff --git a/core/services/ocr2/plugins/functions/reporting.go b/core/services/ocr2/plugins/functions/reporting.go index 36e8a882734..d9d68ec9097 100644 --- a/core/services/ocr2/plugins/functions/reporting.go +++ b/core/services/ocr2/plugins/functions/reporting.go @@ -18,7 +18,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/functions" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/functions/config" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/functions/encoding" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) type FunctionsReportingPluginFactory struct { @@ -151,7 +150,7 @@ func (r *functionsReporting) Query(ctx context.Context, ts types.ReportTimestamp "oracleID": r.genericConfig.OracleID, }) maxBatchSize := r.specificConfig.Config.GetMaxRequestBatchSize() - results, err := r.pluginORM.FindOldestEntriesByState(functions.RESULT_READY, maxBatchSize, pg.WithParentCtx(ctx)) + results, err := r.pluginORM.FindOldestEntriesByState(ctx, functions.RESULT_READY, maxBatchSize) if err != nil { return nil, err } @@ -222,7 +221,7 @@ func (r *functionsReporting) Observation(ctx context.Context, ts types.ReportTim continue } processedIds[id] = true - localResult, err2 := r.pluginORM.FindById(id, pg.WithParentCtx(ctx)) + localResult, err2 := r.pluginORM.FindById(ctx, id) if err2 != nil { r.logger.Debug("FunctionsReporting Observation can't find request from query", commontypes.LogFields{ "requestID": formatRequestId(id[:]), @@ -429,14 +428,14 @@ func (r *functionsReporting) ShouldAcceptFinalizedReport(ctx context.Context, ts r.logger.Error("FunctionsReporting ShouldAcceptFinalizedReport: invalid ID", commontypes.LogFields{"requestID": reqIdStr, "err": err}) continue } - _, err = r.pluginORM.FindById(id, pg.WithParentCtx(ctx)) + _, err = r.pluginORM.FindById(ctx, id) if err != nil { // TODO: Differentiate between ID not found and other ORM errors (https://smartcontract-it.atlassian.net/browse/DRO-215) r.logger.Warn("FunctionsReporting ShouldAcceptFinalizedReport: request doesn't exist locally! Accepting anyway.", commontypes.LogFields{"requestID": reqIdStr}) needTransmissionIds = append(needTransmissionIds, reqIdStr) continue } - err = r.pluginORM.SetFinalized(id, item.Result, item.Error, pg.WithParentCtx(ctx)) // validates state transition + err = r.pluginORM.SetFinalized(ctx, id, item.Result, item.Error) // validates state transition if err != nil { r.logger.Debug("FunctionsReporting ShouldAcceptFinalizedReport: state couldn't be changed to FINALIZED. Not transmitting.", commontypes.LogFields{"requestID": reqIdStr, "err": err}) continue @@ -490,7 +489,7 @@ func (r *functionsReporting) ShouldTransmitAcceptedReport(ctx context.Context, t r.logger.Error("FunctionsReporting ShouldAcceptFinalizedReport: invalid ID", commontypes.LogFields{"requestID": reqIdStr, "err": err}) continue } - request, err := r.pluginORM.FindById(id, pg.WithParentCtx(ctx)) + request, err := r.pluginORM.FindById(ctx, id) if err != nil { r.logger.Warn("FunctionsReporting ShouldTransmitAcceptedReport: request doesn't exist locally! Transmitting anyway.", commontypes.LogFields{"requestID": reqIdStr, "err": err}) needTransmissionIds = append(needTransmissionIds, reqIdStr) diff --git a/core/services/ocr2/plugins/functions/reporting_test.go b/core/services/ocr2/plugins/functions/reporting_test.go index 5b9f59ccb23..7d6686a0b4f 100644 --- a/core/services/ocr2/plugins/functions/reporting_test.go +++ b/core/services/ocr2/plugins/functions/reporting_test.go @@ -134,7 +134,7 @@ func TestFunctionsReporting_Query(t *testing.T) { const batchSize = 10 plugin, orm, _, _ := preparePlugin(t, batchSize, 0) reqs := []functions_srv.Request{newRequest(), newRequest()} - orm.On("FindOldestEntriesByState", functions_srv.RESULT_READY, uint32(batchSize), mock.Anything).Return(reqs, nil) + orm.On("FindOldestEntriesByState", mock.Anything, functions_srv.RESULT_READY, uint32(batchSize), mock.Anything).Return(reqs, nil) q, err := plugin.Query(testutils.Context(t), types.ReportTimestamp{}) require.NoError(t, err) @@ -154,7 +154,7 @@ func TestFunctionsReporting_Query_HandleCoordinatorMismatch(t *testing.T) { reqs := []functions_srv.Request{newRequest(), newRequest()} reqs[0].CoordinatorContractAddress = &common.Address{1} reqs[1].CoordinatorContractAddress = &common.Address{2} - orm.On("FindOldestEntriesByState", functions_srv.RESULT_READY, uint32(batchSize), mock.Anything).Return(reqs, nil) + orm.On("FindOldestEntriesByState", mock.Anything, functions_srv.RESULT_READY, uint32(batchSize), mock.Anything).Return(reqs, nil) q, err := plugin.Query(testutils.Context(t), types.ReportTimestamp{}) require.NoError(t, err) @@ -177,11 +177,11 @@ func TestFunctionsReporting_Observation(t *testing.T) { req4 := newRequestTimedOut() nonexistentId := newRequestID() - orm.On("FindById", req1.RequestID, mock.Anything).Return(&req1, nil) - orm.On("FindById", req2.RequestID, mock.Anything).Return(&req2, nil) - orm.On("FindById", req3.RequestID, mock.Anything).Return(&req3, nil) - orm.On("FindById", req4.RequestID, mock.Anything).Return(&req4, nil) - orm.On("FindById", nonexistentId, mock.Anything).Return(nil, errors.New("nonexistent ID")) + orm.On("FindById", mock.Anything, req1.RequestID, mock.Anything).Return(&req1, nil) + orm.On("FindById", mock.Anything, req2.RequestID, mock.Anything).Return(&req2, nil) + orm.On("FindById", mock.Anything, req3.RequestID, mock.Anything).Return(&req3, nil) + orm.On("FindById", mock.Anything, req4.RequestID, mock.Anything).Return(&req4, nil) + orm.On("FindById", mock.Anything, nonexistentId, mock.Anything).Return(nil, errors.New("nonexistent ID")) // Query asking for 5 requests (with duplicates), out of which: // - two are ready @@ -209,7 +209,7 @@ func TestFunctionsReporting_Observation_IncorrectQuery(t *testing.T) { req1 := newRequestWithResult([]byte("abc")) invalidId := []byte("invalid") - orm.On("FindById", req1.RequestID, mock.Anything).Return(&req1, nil) + orm.On("FindById", mock.Anything, req1.RequestID, mock.Anything).Return(&req1, nil) // Query asking for 3 requests (with duplicates), out of which: // - two are invalid @@ -441,13 +441,13 @@ func TestFunctionsReporting_ShouldAcceptFinalizedReport(t *testing.T) { req3 := newRequestFinalized() req4 := newRequestTimedOut() - orm.On("FindById", req1.RequestID, mock.Anything).Return(nil, errors.New("nonexistent ID")) - orm.On("FindById", req2.RequestID, mock.Anything).Return(&req2, nil) - orm.On("SetFinalized", req2.RequestID, mock.Anything, mock.Anything, mock.Anything).Return(nil) - orm.On("FindById", req3.RequestID, mock.Anything).Return(&req3, nil) - orm.On("SetFinalized", req3.RequestID, mock.Anything, mock.Anything, mock.Anything).Return(errors.New("same state")) - orm.On("FindById", req4.RequestID, mock.Anything).Return(&req4, nil) - orm.On("SetFinalized", req4.RequestID, mock.Anything, mock.Anything, mock.Anything).Return(errors.New("already timed out")) + orm.On("FindById", mock.Anything, req1.RequestID, mock.Anything).Return(nil, errors.New("nonexistent ID")) + orm.On("FindById", mock.Anything, req2.RequestID, mock.Anything).Return(&req2, nil) + orm.On("SetFinalized", mock.Anything, req2.RequestID, mock.Anything, mock.Anything, mock.Anything).Return(nil) + orm.On("FindById", mock.Anything, req3.RequestID, mock.Anything).Return(&req3, nil) + orm.On("SetFinalized", mock.Anything, req3.RequestID, mock.Anything, mock.Anything, mock.Anything).Return(errors.New("same state")) + orm.On("FindById", mock.Anything, req4.RequestID, mock.Anything).Return(&req4, nil) + orm.On("SetFinalized", mock.Anything, req4.RequestID, mock.Anything, mock.Anything, mock.Anything).Return(errors.New("already timed out")) // Attempting to transmit 2 requests, out of which: // - one was already accepted for transmission earlier @@ -477,8 +477,8 @@ func TestFunctionsReporting_ShouldAcceptFinalizedReport_OffchainTransmission(t * req1 := newRequestWithResult([]byte("abc")) req1.OnchainMetadata = []byte(functions_srv.OffchainRequestMarker) - orm.On("FindById", req1.RequestID, mock.Anything).Return(&req1, nil) - orm.On("SetFinalized", req1.RequestID, mock.Anything, mock.Anything, mock.Anything).Return(nil) + orm.On("FindById", mock.Anything, req1.RequestID, mock.Anything).Return(&req1, nil) + orm.On("SetFinalized", mock.Anything, req1.RequestID, mock.Anything, mock.Anything, mock.Anything).Return(nil) offchainTransmitter.On("TransmitReport", mock.Anything, mock.Anything).Return(nil) should, err := plugin.ShouldAcceptFinalizedReport(testutils.Context(t), types.ReportTimestamp{}, getReportBytes(t, codec, req1)) @@ -496,11 +496,11 @@ func TestFunctionsReporting_ShouldTransmitAcceptedReport(t *testing.T) { req4 := newRequestTimedOut() req5 := newRequestConfirmed() - orm.On("FindById", req1.RequestID, mock.Anything).Return(nil, errors.New("nonexistent ID")) - orm.On("FindById", req2.RequestID, mock.Anything).Return(&req2, nil) - orm.On("FindById", req3.RequestID, mock.Anything).Return(&req3, nil) - orm.On("FindById", req4.RequestID, mock.Anything).Return(&req4, nil) - orm.On("FindById", req5.RequestID, mock.Anything).Return(&req5, nil) + orm.On("FindById", mock.Anything, req1.RequestID, mock.Anything).Return(nil, errors.New("nonexistent ID")) + orm.On("FindById", mock.Anything, req2.RequestID, mock.Anything).Return(&req2, nil) + orm.On("FindById", mock.Anything, req3.RequestID, mock.Anything).Return(&req3, nil) + orm.On("FindById", mock.Anything, req4.RequestID, mock.Anything).Return(&req4, nil) + orm.On("FindById", mock.Anything, req5.RequestID, mock.Anything).Return(&req5, nil) // Attempting to transmit 2 requests, out of which: // - one was already confirmed on chain diff --git a/core/services/ocr2/plugins/s4/integration_test.go b/core/services/ocr2/plugins/s4/integration_test.go index 8efe38f8e2d..5148ea6e26d 100644 --- a/core/services/ocr2/plugins/s4/integration_test.go +++ b/core/services/ocr2/plugins/s4/integration_test.go @@ -15,7 +15,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/pgtest" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/s4" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" s4_svc "github.com/smartcontractkit/chainlink/v2/core/services/s4" commonlogger "github.com/smartcontractkit/chainlink-common/pkg/logger" @@ -53,7 +52,7 @@ func newDON(t *testing.T, size int, config *s4.PluginConfig) *don { for i := 0; i < size; i++ { ns := fmt.Sprintf("s4_int_test_%d", i) - orm := s4_svc.NewPostgresORM(db, logger, pgtest.NewQConfig(false), s4_svc.SharedTableName, ns) + orm := s4_svc.NewPostgresORM(db, s4_svc.SharedTableName, ns) orms[i] = orm ocrLogger := commonlogger.NewOCRWrapper(logger, true, func(msg string) {}) @@ -149,7 +148,7 @@ func checkNoErrors(t *testing.T, errors []error) { func checkNoUnconfirmedRows(ctx context.Context, t *testing.T, orm s4_svc.ORM, limit uint) { t.Helper() - rows, err := orm.GetUnconfirmedRows(limit, pg.WithParentCtx(ctx)) + rows, err := orm.GetUnconfirmedRows(ctx, limit) assert.NoError(t, err) assert.Empty(t, rows) } @@ -161,10 +160,10 @@ func TestS4Integration_HappyDON(t *testing.T) { // injecting new records rows := generateTestOrmRows(t, 10, time.Minute) for _, row := range rows { - err := don.orms[0].Update(row, pg.WithParentCtx(ctx)) + err := don.orms[0].Update(ctx, row) require.NoError(t, err) } - originSnapshot, err := don.orms[0].GetSnapshot(s4_svc.NewFullAddressRange(), pg.WithParentCtx(ctx)) + originSnapshot, err := don.orms[0].GetSnapshot(ctx, s4_svc.NewFullAddressRange()) require.NoError(t, err) // S4 to propagate all records in one OCR round @@ -172,7 +171,7 @@ func TestS4Integration_HappyDON(t *testing.T) { checkNoErrors(t, errors) for i := 0; i < don.size; i++ { - snapshot, err := don.orms[i].GetSnapshot(s4_svc.NewFullAddressRange(), pg.WithParentCtx(ctx)) + snapshot, err := don.orms[i].GetSnapshot(ctx, s4_svc.NewFullAddressRange()) require.NoError(t, err) equal := compareSnapshots(originSnapshot, snapshot) assert.True(t, equal, "oracle %d", i) @@ -188,7 +187,7 @@ func TestS4Integration_HappyDON_4X(t *testing.T) { for o := 0; o < don.size; o++ { rows := generateTestOrmRows(t, 10, time.Minute) for _, row := range rows { - err := don.orms[o].Update(row, pg.WithParentCtx(ctx)) + err := don.orms[o].Update(ctx, row) require.NoError(t, err) } } @@ -197,11 +196,11 @@ func TestS4Integration_HappyDON_4X(t *testing.T) { errors := don.simulateOCR(ctx, 1) checkNoErrors(t, errors) - firstSnapshot, err := don.orms[0].GetSnapshot(s4_svc.NewFullAddressRange(), pg.WithParentCtx(ctx)) + firstSnapshot, err := don.orms[0].GetSnapshot(ctx, s4_svc.NewFullAddressRange()) require.NoError(t, err) for i := 1; i < don.size; i++ { - snapshot, err := don.orms[i].GetSnapshot(s4_svc.NewFullAddressRange(), pg.WithParentCtx(ctx)) + snapshot, err := don.orms[i].GetSnapshot(ctx, s4_svc.NewFullAddressRange()) require.NoError(t, err) equal := compareSnapshots(firstSnapshot, snapshot) assert.True(t, equal, "oracle %d", i) @@ -217,10 +216,10 @@ func TestS4Integration_WrongSignature(t *testing.T) { rows := generateTestOrmRows(t, 10, time.Minute) rows[0].Signature = rows[1].Signature for _, row := range rows { - err := don.orms[0].Update(row, pg.WithParentCtx(ctx)) + err := don.orms[0].Update(ctx, row) require.NoError(t, err) } - originSnapshot, err := don.orms[0].GetSnapshot(s4_svc.NewFullAddressRange(), pg.WithParentCtx(ctx)) + originSnapshot, err := don.orms[0].GetSnapshot(ctx, s4_svc.NewFullAddressRange()) require.NoError(t, err) originSnapshot = filter(originSnapshot, func(row *s4_svc.SnapshotRow) bool { return row.Address.Cmp(rows[0].Address) != 0 || row.SlotId != rows[0].SlotId @@ -232,14 +231,14 @@ func TestS4Integration_WrongSignature(t *testing.T) { checkNoErrors(t, errors) for i := 1; i < don.size; i++ { - snapshot, err2 := don.orms[i].GetSnapshot(s4_svc.NewFullAddressRange(), pg.WithParentCtx(ctx)) + snapshot, err2 := don.orms[i].GetSnapshot(ctx, s4_svc.NewFullAddressRange()) require.NoError(t, err2) equal := compareSnapshots(originSnapshot, snapshot) assert.True(t, equal, "oracle %d", i) } // record with a wrong signature must remain unconfirmed - ur, err := don.orms[0].GetUnconfirmedRows(10, pg.WithParentCtx(ctx)) + ur, err := don.orms[0].GetUnconfirmedRows(ctx, 10) require.NoError(t, err) require.Len(t, ur, 1) } @@ -253,10 +252,10 @@ func TestS4Integration_MaxObservations(t *testing.T) { // injecting new records rows := generateTestOrmRows(t, 10, time.Minute) for _, row := range rows { - err := don.orms[0].Update(row, pg.WithParentCtx(ctx)) + err := don.orms[0].Update(ctx, row) require.NoError(t, err) } - originSnapshot, err := don.orms[0].GetSnapshot(s4_svc.NewFullAddressRange(), pg.WithParentCtx(ctx)) + originSnapshot, err := don.orms[0].GetSnapshot(ctx, s4_svc.NewFullAddressRange()) require.NoError(t, err) // It requires at least two rounds due to MaxObservationEntries = rows / 2 @@ -264,7 +263,7 @@ func TestS4Integration_MaxObservations(t *testing.T) { checkNoErrors(t, errors) for i := 1; i < don.size; i++ { - snapshot, err := don.orms[i].GetSnapshot(s4_svc.NewFullAddressRange(), pg.WithParentCtx(ctx)) + snapshot, err := don.orms[i].GetSnapshot(ctx, s4_svc.NewFullAddressRange()) require.NoError(t, err) equal := compareSnapshots(originSnapshot, snapshot) assert.True(t, equal, "oracle %d", i) @@ -280,7 +279,7 @@ func TestS4Integration_Expired(t *testing.T) { // injecting expiring records rows := generateTestOrmRows(t, 10, time.Millisecond) for _, row := range rows { - err := don.orms[0].Update(row, pg.WithParentCtx(ctx)) + err := don.orms[0].Update(ctx, row) require.NoError(t, err) } @@ -290,7 +289,7 @@ func TestS4Integration_Expired(t *testing.T) { checkNoErrors(t, errors) for i := 0; i < don.size; i++ { - snapshot, err := don.orms[i].GetSnapshot(s4_svc.NewFullAddressRange(), pg.WithParentCtx(ctx)) + snapshot, err := don.orms[i].GetSnapshot(ctx, s4_svc.NewFullAddressRange()) require.NoError(t, err) require.Len(t, snapshot, 0) } @@ -305,10 +304,10 @@ func TestS4Integration_NSnapshotShards(t *testing.T) { // injecting lots of new records (to be close to normal address distribution) rows := generateTestOrmRows(t, 1000, time.Minute) for _, row := range rows { - err := don.orms[0].Update(row, pg.WithParentCtx(ctx)) + err := don.orms[0].Update(ctx, row) require.NoError(t, err) } - originSnapshot, err := don.orms[0].GetSnapshot(s4_svc.NewFullAddressRange(), pg.WithParentCtx(ctx)) + originSnapshot, err := don.orms[0].GetSnapshot(ctx, s4_svc.NewFullAddressRange()) require.NoError(t, err) // this still requires one round, because Observation takes all unconfirmed rows @@ -316,7 +315,7 @@ func TestS4Integration_NSnapshotShards(t *testing.T) { checkNoErrors(t, errors) for i := 1; i < don.size; i++ { - snapshot, err := don.orms[i].GetSnapshot(s4_svc.NewFullAddressRange(), pg.WithParentCtx(ctx)) + snapshot, err := don.orms[i].GetSnapshot(ctx, s4_svc.NewFullAddressRange()) require.NoError(t, err) equal := compareSnapshots(originSnapshot, snapshot) assert.True(t, equal, "oracle %d", i) @@ -332,7 +331,7 @@ func TestS4Integration_OneNodeOutOfSync(t *testing.T) { rows := generateConfirmedTestOrmRows(t, 10, time.Minute) for o := 0; o < don.size-1; o++ { for _, row := range rows { - err := don.orms[o].Update(row, pg.WithParentCtx(ctx)) + err := don.orms[o].Update(ctx, row) require.NoError(t, err) } } @@ -342,9 +341,9 @@ func TestS4Integration_OneNodeOutOfSync(t *testing.T) { errors := don.simulateOCR(ctx, 4) checkNoErrors(t, errors) - firstSnapshot, err := don.orms[0].GetSnapshot(s4_svc.NewFullAddressRange(), pg.WithParentCtx(ctx)) + firstSnapshot, err := don.orms[0].GetSnapshot(ctx, s4_svc.NewFullAddressRange()) require.NoError(t, err) - lastSnapshot, err := don.orms[don.size-1].GetSnapshot(s4_svc.NewFullAddressRange(), pg.WithParentCtx(ctx)) + lastSnapshot, err := don.orms[don.size-1].GetSnapshot(ctx, s4_svc.NewFullAddressRange()) require.NoError(t, err) equal := compareSnapshots(firstSnapshot, lastSnapshot) assert.True(t, equal) @@ -389,7 +388,7 @@ func TestS4Integration_RandomState(t *testing.T) { sig, err := env.Sign(user.privateKey) require.NoError(t, err) row.Signature = sig - err = don.orms[o].Update(row, pg.WithParentCtx(ctx)) + err = don.orms[o].Update(ctx, row) require.NoError(t, err) } } @@ -398,13 +397,13 @@ func TestS4Integration_RandomState(t *testing.T) { errors := don.simulateOCR(ctx, 4) checkNoErrors(t, errors) - firstSnapshot, err := don.orms[0].GetSnapshot(s4_svc.NewFullAddressRange(), pg.WithParentCtx(ctx)) + firstSnapshot, err := don.orms[0].GetSnapshot(ctx, s4_svc.NewFullAddressRange()) require.NoError(t, err) require.NotEmpty(t, firstSnapshot) checkNoUnconfirmedRows(ctx, t, don.orms[0], 1000) for i := 1; i < don.size; i++ { - snapshot, err := don.orms[i].GetSnapshot(s4_svc.NewFullAddressRange(), pg.WithParentCtx(ctx)) + snapshot, err := don.orms[i].GetSnapshot(ctx, s4_svc.NewFullAddressRange()) require.NoError(t, err) equal := compareSnapshots(firstSnapshot, snapshot) assert.True(t, equal, "oracle %d", i) diff --git a/core/services/ocr2/plugins/s4/plugin.go b/core/services/ocr2/plugins/s4/plugin.go index 2b55ebf3cc5..6976c606045 100644 --- a/core/services/ocr2/plugins/s4/plugin.go +++ b/core/services/ocr2/plugins/s4/plugin.go @@ -12,7 +12,6 @@ import ( "github.com/smartcontractkit/libocr/offchainreporting2plus/types" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/s4" ) @@ -69,7 +68,7 @@ func NewReportingPlugin(logger commontypes.Logger, config *PluginConfig, orm s4. func (c *plugin) Query(ctx context.Context, ts types.ReportTimestamp) (types.Query, error) { promReportingPluginQuery.WithLabelValues(c.config.ProductName).Inc() - snapshot, err := c.orm.GetSnapshot(c.addressRange, pg.WithParentCtx(ctx)) + snapshot, err := c.orm.GetSnapshot(ctx, c.addressRange) if err != nil { return nil, errors.Wrap(err, "failed to GetVersions in Query()") } @@ -111,7 +110,7 @@ func (c *plugin) Observation(ctx context.Context, ts types.ReportTimestamp, quer promReportingPluginObservation.WithLabelValues(c.config.ProductName).Inc() now := time.Now().UTC() - count, err := c.orm.DeleteExpired(c.config.MaxDeleteExpiredEntries, now, pg.WithParentCtx(ctx)) + count, err := c.orm.DeleteExpired(ctx, c.config.MaxDeleteExpiredEntries, now) if err != nil { return nil, errors.Wrap(err, "failed to DeleteExpired in Observation()") } @@ -122,7 +121,7 @@ func (c *plugin) Observation(ctx context.Context, ts types.ReportTimestamp, quer return MarshalRows(convertRows(rows)) } - unconfirmedRows, err := c.orm.GetUnconfirmedRows(c.config.MaxObservationEntries, pg.WithParentCtx(ctx)) + unconfirmedRows, err := c.orm.GetUnconfirmedRows(ctx, c.config.MaxObservationEntries) if err != nil { return nil, errors.Wrap(err, "failed to GetUnconfirmedRows in Observation()") } @@ -138,7 +137,7 @@ func (c *plugin) Observation(ctx context.Context, ts types.ReportTimestamp, quer if err != nil { c.logger.Error("Failed to unmarshal query (likely malformed)", commontypes.LogFields{"err": err}) } else { - snapshot, err := c.orm.GetSnapshot(addressRange, pg.WithParentCtx(ctx)) + snapshot, err := c.orm.GetSnapshot(ctx, addressRange) if err != nil { c.logger.Error("ORM GetSnapshot error", commontypes.LogFields{"err": err}) } else { @@ -178,7 +177,7 @@ func (c *plugin) Observation(ctx context.Context, ts types.ReportTimestamp, quer } for _, k := range toBeAdded { - row, err := c.orm.Get(k.address, k.slotID, pg.WithParentCtx(ctx)) + row, err := c.orm.Get(ctx, k.address, k.slotID) if err == nil { remainingRows = append(remainingRows, row) } else if !errors.Is(err, s4.ErrNotFound) { @@ -283,7 +282,7 @@ func (c *plugin) ShouldAcceptFinalizedReport(ctx context.Context, ts types.Repor continue } - err = c.orm.Update(ormRow, pg.WithParentCtx(ctx)) + err = c.orm.Update(ctx, ormRow) if err != nil && !errors.Is(err, s4.ErrVersionTooLow) { c.logger.Error("Failed to Update a row in ShouldAcceptFinalizedReport()", commontypes.LogFields{"err": err}) continue diff --git a/core/services/ocr2/plugins/s4/plugin_test.go b/core/services/ocr2/plugins/s4/plugin_test.go index b53ab40bfcb..6321b8ce867 100644 --- a/core/services/ocr2/plugins/s4/plugin_test.go +++ b/core/services/ocr2/plugins/s4/plugin_test.go @@ -205,7 +205,7 @@ func TestPlugin_ShouldAcceptFinalizedReport(t *testing.T) { ormRows := make([]*s4_svc.Row, 0) rows := generateTestRows(t, 10, time.Minute) orm.On("Update", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { - updateRow := args.Get(0).(*s4_svc.Row) + updateRow := args.Get(1).(*s4_svc.Row) ormRows = append(ormRows, updateRow) }).Return(nil).Times(10) @@ -344,8 +344,8 @@ func TestPlugin_Observation(t *testing.T) { for _, or := range ormRows { or.Confirmed = false } - orm.On("DeleteExpired", uint(10), mock.Anything, mock.Anything).Return(int64(10), nil).Once() - orm.On("GetUnconfirmedRows", config.MaxObservationEntries, mock.Anything).Return(ormRows, nil).Once() + orm.On("DeleteExpired", mock.Anything, uint(10), mock.Anything, mock.Anything).Return(int64(10), nil).Once() + orm.On("GetUnconfirmedRows", mock.Anything, config.MaxObservationEntries).Return(ormRows, nil).Once() observation, err := plugin.Observation(testutils.Context(t), types.ReportTimestamp{}, []byte{}) assert.NoError(t, err) @@ -370,8 +370,8 @@ func TestPlugin_Observation(t *testing.T) { Confirmed: or.Confirmed, } } - orm.On("DeleteExpired", uint(10), mock.Anything, mock.Anything).Return(int64(10), nil).Once() - orm.On("GetUnconfirmedRows", config.MaxObservationEntries, mock.Anything).Return(ormRows[numUnconfirmed:], nil).Once() + orm.On("DeleteExpired", mock.Anything, uint(10), mock.Anything, mock.Anything).Return(int64(10), nil).Once() + orm.On("GetUnconfirmedRows", mock.Anything, config.MaxObservationEntries).Return(ormRows[numUnconfirmed:], nil).Once() orm.On("GetSnapshot", mock.Anything, mock.Anything).Return(snapshot, nil).Once() snapshotRows := rowsToShapshotRows(ormRows) @@ -388,7 +388,7 @@ func TestPlugin_Observation(t *testing.T) { if i < numHigherVersion { ormRows[i].Version++ snapshot[i].Version++ - orm.On("Get", v.Address, v.SlotId, mock.Anything).Return(ormRows[i], nil).Once() + orm.On("Get", mock.Anything, v.Address, v.SlotId).Return(ormRows[i], nil).Once() } } queryBytes, err := proto.Marshal(query) @@ -447,11 +447,11 @@ func TestPlugin_Observation(t *testing.T) { queryBytes, err := proto.Marshal(query) assert.NoError(t, err) - orm.On("DeleteExpired", uint(10), mock.Anything, mock.Anything).Return(int64(10), nil).Once() - orm.On("GetUnconfirmedRows", config.MaxObservationEntries, mock.Anything).Return([]*s4_svc.Row{}, nil).Once() + orm.On("DeleteExpired", mock.Anything, uint(10), mock.Anything, mock.Anything).Return(int64(10), nil).Once() + orm.On("GetUnconfirmedRows", mock.Anything, config.MaxObservationEntries).Return([]*s4_svc.Row{}, nil).Once() orm.On("GetSnapshot", mock.Anything, mock.Anything).Return(snapshot, nil).Once() - orm.On("Get", snapshot[1].Address, snapshot[1].SlotId, mock.Anything).Return(ormRows[1], nil).Once() - orm.On("Get", snapshot[2].Address, snapshot[2].SlotId, mock.Anything).Return(ormRows[2], nil).Once() + orm.On("Get", mock.Anything, snapshot[1].Address, snapshot[1].SlotId).Return(ormRows[1], nil).Once() + orm.On("Get", mock.Anything, snapshot[2].Address, snapshot[2].SlotId).Return(ormRows[2], nil).Once() observation, err := plugin.Observation(testutils.Context(t), types.ReportTimestamp{}, queryBytes) assert.NoError(t, err) diff --git a/core/services/s4/cached_orm_wrapper.go b/core/services/s4/cached_orm_wrapper.go index 38b9ecba1ca..fe6cb20e3cd 100644 --- a/core/services/s4/cached_orm_wrapper.go +++ b/core/services/s4/cached_orm_wrapper.go @@ -1,6 +1,7 @@ package s4 import ( + "context" "fmt" "math/big" "strings" @@ -10,7 +11,6 @@ import ( ubig "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big" "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) const ( @@ -40,18 +40,18 @@ func NewCachedORMWrapper(orm ORM, lggr logger.Logger) *CachedORM { } } -func (c CachedORM) Get(address *ubig.Big, slotId uint, qopts ...pg.QOpt) (*Row, error) { - return c.underlayingORM.Get(address, slotId, qopts...) +func (c CachedORM) Get(ctx context.Context, address *ubig.Big, slotId uint) (*Row, error) { + return c.underlayingORM.Get(ctx, address, slotId) } -func (c CachedORM) Update(row *Row, qopts ...pg.QOpt) error { +func (c CachedORM) Update(ctx context.Context, row *Row) error { c.deleteRowFromSnapshotCache(row) - return c.underlayingORM.Update(row, qopts...) + return c.underlayingORM.Update(ctx, row) } -func (c CachedORM) DeleteExpired(limit uint, utcNow time.Time, qopts ...pg.QOpt) (int64, error) { - deletedRows, err := c.underlayingORM.DeleteExpired(limit, utcNow, qopts...) +func (c CachedORM) DeleteExpired(ctx context.Context, limit uint, utcNow time.Time) (int64, error) { + deletedRows, err := c.underlayingORM.DeleteExpired(ctx, limit, utcNow) if err != nil { return 0, err } @@ -63,7 +63,7 @@ func (c CachedORM) DeleteExpired(limit uint, utcNow time.Time, qopts ...pg.QOpt) return deletedRows, nil } -func (c CachedORM) GetSnapshot(addressRange *AddressRange, qopts ...pg.QOpt) ([]*SnapshotRow, error) { +func (c CachedORM) GetSnapshot(ctx context.Context, addressRange *AddressRange) ([]*SnapshotRow, error) { key := fmt.Sprintf("%s_%s_%s", getSnapshotCachePrefix, addressRange.MinAddress.String(), addressRange.MaxAddress.String()) cached, found := c.cache.Get(key) @@ -72,7 +72,7 @@ func (c CachedORM) GetSnapshot(addressRange *AddressRange, qopts ...pg.QOpt) ([] } c.lggr.Debug("Snapshot not found in cache, fetching it from underlaying implementation") - data, err := c.underlayingORM.GetSnapshot(addressRange, qopts...) + data, err := c.underlayingORM.GetSnapshot(ctx, addressRange) if err != nil { return nil, err } @@ -81,8 +81,8 @@ func (c CachedORM) GetSnapshot(addressRange *AddressRange, qopts ...pg.QOpt) ([] return data, nil } -func (c CachedORM) GetUnconfirmedRows(limit uint, qopts ...pg.QOpt) ([]*Row, error) { - return c.underlayingORM.GetUnconfirmedRows(limit, qopts...) +func (c CachedORM) GetUnconfirmedRows(ctx context.Context, limit uint) ([]*Row, error) { + return c.underlayingORM.GetUnconfirmedRows(ctx, limit) } // deleteRowFromSnapshotCache will clean the cache for every snapshot that would involve a given row diff --git a/core/services/s4/cached_orm_wrapper_test.go b/core/services/s4/cached_orm_wrapper_test.go index 6f6ac298557..5b94ce3b253 100644 --- a/core/services/s4/cached_orm_wrapper_test.go +++ b/core/services/s4/cached_orm_wrapper_test.go @@ -8,6 +8,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big" @@ -21,11 +22,12 @@ import ( func TestGetSnapshotEmpty(t *testing.T) { t.Run("OK-no_rows", func(t *testing.T) { + ctx := testutils.Context(t) psqlORM := setupORM(t, "test") lggr := logger.TestLogger(t) orm := s4.NewCachedORMWrapper(psqlORM, lggr) - rows, err := orm.GetSnapshot(s4.NewFullAddressRange()) + rows, err := orm.GetSnapshot(ctx, s4.NewFullAddressRange()) assert.NoError(t, err) assert.Empty(t, rows) }) @@ -33,23 +35,24 @@ func TestGetSnapshotEmpty(t *testing.T) { func TestGetSnapshotCacheFilled(t *testing.T) { t.Run("OK_with_rows_already_cached", func(t *testing.T) { + ctx := testutils.Context(t) rows := generateTestSnapshotRows(t, 100) fullAddressRange := s4.NewFullAddressRange() lggr := logger.TestLogger(t) underlayingORM := mocks.NewORM(t) - underlayingORM.On("GetSnapshot", fullAddressRange).Return(rows, nil).Once() + underlayingORM.On("GetSnapshot", mock.Anything, fullAddressRange).Return(rows, nil).Once() orm := s4.NewCachedORMWrapper(underlayingORM, lggr) // first call will go to the underlaying orm implementation to fill the cache - first_snapshot, err := orm.GetSnapshot(fullAddressRange) + first_snapshot, err := orm.GetSnapshot(ctx, fullAddressRange) assert.NoError(t, err) assert.Equal(t, len(rows), len(first_snapshot)) // on the second call, the results will come from the cache, if not the mock will return an error because of .Once() - cache_snapshot, err := orm.GetSnapshot(fullAddressRange) + cache_snapshot, err := orm.GetSnapshot(ctx, fullAddressRange) assert.NoError(t, err) assert.Equal(t, len(rows), len(cache_snapshot)) @@ -75,23 +78,24 @@ func TestGetSnapshotCacheFilled(t *testing.T) { func TestUpdateInvalidatesSnapshotCache(t *testing.T) { t.Run("OK-GetSnapshot_cache_invalidated_after_update", func(t *testing.T) { + ctx := testutils.Context(t) rows := generateTestSnapshotRows(t, 100) fullAddressRange := s4.NewFullAddressRange() lggr := logger.TestLogger(t) underlayingORM := mocks.NewORM(t) - underlayingORM.On("GetSnapshot", fullAddressRange).Return(rows, nil).Once() + underlayingORM.On("GetSnapshot", mock.Anything, fullAddressRange).Return(rows, nil).Once() orm := s4.NewCachedORMWrapper(underlayingORM, lggr) // first call will go to the underlaying orm implementation to fill the cache - first_snapshot, err := orm.GetSnapshot(fullAddressRange) + first_snapshot, err := orm.GetSnapshot(ctx, fullAddressRange) assert.NoError(t, err) assert.Equal(t, len(rows), len(first_snapshot)) // on the second call, the results will come from the cache, if not the mock will return an error because of .Once() - cache_snapshot, err := orm.GetSnapshot(fullAddressRange) + cache_snapshot, err := orm.GetSnapshot(ctx, fullAddressRange) assert.NoError(t, err) assert.Equal(t, len(rows), len(cache_snapshot)) @@ -105,18 +109,19 @@ func TestUpdateInvalidatesSnapshotCache(t *testing.T) { Confirmed: true, Signature: cltest.MustRandomBytes(t, 32), } - underlayingORM.On("Update", row).Return(nil).Once() - err = orm.Update(row) + underlayingORM.On("Update", mock.Anything, row).Return(nil).Once() + err = orm.Update(ctx, row) assert.NoError(t, err) // given the cache was invalidated this request will reach the underlaying orm implementation - underlayingORM.On("GetSnapshot", fullAddressRange).Return(rows, nil).Once() - third_snapshot, err := orm.GetSnapshot(fullAddressRange) + underlayingORM.On("GetSnapshot", mock.Anything, fullAddressRange).Return(rows, nil).Once() + third_snapshot, err := orm.GetSnapshot(ctx, fullAddressRange) assert.NoError(t, err) assert.Equal(t, len(rows), len(third_snapshot)) }) t.Run("OK-GetSnapshot_cache_not_invalidated_after_update", func(t *testing.T) { + ctx := testutils.Context(t) rows := generateTestSnapshotRows(t, 5) addressRange := &s4.AddressRange{ @@ -126,17 +131,17 @@ func TestUpdateInvalidatesSnapshotCache(t *testing.T) { lggr := logger.TestLogger(t) underlayingORM := mocks.NewORM(t) - underlayingORM.On("GetSnapshot", addressRange).Return(rows, nil).Once() + underlayingORM.On("GetSnapshot", mock.Anything, addressRange).Return(rows, nil).Once() orm := s4.NewCachedORMWrapper(underlayingORM, lggr) // first call will go to the underlaying orm implementation to fill the cache - first_snapshot, err := orm.GetSnapshot(addressRange) + first_snapshot, err := orm.GetSnapshot(ctx, addressRange) assert.NoError(t, err) assert.Equal(t, len(rows), len(first_snapshot)) // on the second call, the results will come from the cache, if not the mock will return an error because of .Once() - cache_snapshot, err := orm.GetSnapshot(addressRange) + cache_snapshot, err := orm.GetSnapshot(ctx, addressRange) assert.NoError(t, err) assert.Equal(t, len(rows), len(cache_snapshot)) @@ -151,12 +156,12 @@ func TestUpdateInvalidatesSnapshotCache(t *testing.T) { Confirmed: true, Signature: cltest.MustRandomBytes(t, 32), } - underlayingORM.On("Update", row).Return(nil).Once() - err = orm.Update(row) + underlayingORM.On("Update", mock.Anything, row).Return(nil).Once() + err = orm.Update(ctx, row) assert.NoError(t, err) // given the cache was not invalidated this request wont reach the underlaying orm implementation - third_snapshot, err := orm.GetSnapshot(addressRange) + third_snapshot, err := orm.GetSnapshot(ctx, addressRange) assert.NoError(t, err) assert.Equal(t, len(rows), len(third_snapshot)) }) @@ -169,24 +174,26 @@ func TestGet(t *testing.T) { lggr := logger.TestLogger(t) t.Run("OK-Get_underlaying_ORM_returns_a_row", func(t *testing.T) { + ctx := testutils.Context(t) underlayingORM := mocks.NewORM(t) expectedRow := &s4.Row{ Address: address, SlotId: slotID, } - underlayingORM.On("Get", address, slotID).Return(expectedRow, nil).Once() + underlayingORM.On("Get", mock.Anything, address, slotID).Return(expectedRow, nil).Once() orm := s4.NewCachedORMWrapper(underlayingORM, lggr) - row, err := orm.Get(address, slotID) + row, err := orm.Get(ctx, address, slotID) require.NoError(t, err) require.Equal(t, expectedRow, row) }) t.Run("NOK-Get_underlaying_ORM_returns_an_error", func(t *testing.T) { + ctx := testutils.Context(t) underlayingORM := mocks.NewORM(t) - underlayingORM.On("Get", address, slotID).Return(nil, fmt.Errorf("some_error")).Once() + underlayingORM.On("Get", mock.Anything, address, slotID).Return(nil, fmt.Errorf("some_error")).Once() orm := s4.NewCachedORMWrapper(underlayingORM, lggr) - row, err := orm.Get(address, slotID) + row, err := orm.Get(ctx, address, slotID) require.Nil(t, row) require.EqualError(t, err, "some_error") }) @@ -199,22 +206,24 @@ func TestDeletedExpired(t *testing.T) { lggr := logger.TestLogger(t) t.Run("OK-DeletedExpired_underlaying_ORM_returns_a_row", func(t *testing.T) { + ctx := testutils.Context(t) var expectedDeleted int64 = 10 underlayingORM := mocks.NewORM(t) - underlayingORM.On("DeleteExpired", limit, now).Return(expectedDeleted, nil).Once() + underlayingORM.On("DeleteExpired", mock.Anything, limit, now).Return(expectedDeleted, nil).Once() orm := s4.NewCachedORMWrapper(underlayingORM, lggr) - actualDeleted, err := orm.DeleteExpired(limit, now) + actualDeleted, err := orm.DeleteExpired(ctx, limit, now) require.NoError(t, err) require.Equal(t, expectedDeleted, actualDeleted) }) t.Run("NOK-DeletedExpired_underlaying_ORM_returns_an_error", func(t *testing.T) { + ctx := testutils.Context(t) var expectedDeleted int64 underlayingORM := mocks.NewORM(t) - underlayingORM.On("DeleteExpired", limit, now).Return(expectedDeleted, fmt.Errorf("some_error")).Once() + underlayingORM.On("DeleteExpired", mock.Anything, limit, now).Return(expectedDeleted, fmt.Errorf("some_error")).Once() orm := s4.NewCachedORMWrapper(underlayingORM, lggr) - actualDeleted, err := orm.DeleteExpired(limit, now) + actualDeleted, err := orm.DeleteExpired(ctx, limit, now) require.EqualError(t, err, "some_error") require.Equal(t, expectedDeleted, actualDeleted) }) @@ -226,6 +235,7 @@ func TestGetUnconfirmedRows(t *testing.T) { lggr := logger.TestLogger(t) t.Run("OK-GetUnconfirmedRows_underlaying_ORM_returns_a_row", func(t *testing.T) { + ctx := testutils.Context(t) address := big.New(testutils.NewAddress().Big()) var slotID uint = 1 @@ -234,19 +244,20 @@ func TestGetUnconfirmedRows(t *testing.T) { SlotId: slotID, }} underlayingORM := mocks.NewORM(t) - underlayingORM.On("GetUnconfirmedRows", limit).Return(expectedRow, nil).Once() + underlayingORM.On("GetUnconfirmedRows", mock.Anything, limit).Return(expectedRow, nil).Once() orm := s4.NewCachedORMWrapper(underlayingORM, lggr) - actualRow, err := orm.GetUnconfirmedRows(limit) + actualRow, err := orm.GetUnconfirmedRows(ctx, limit) require.NoError(t, err) require.Equal(t, expectedRow, actualRow) }) t.Run("NOK-GetUnconfirmedRows_underlaying_ORM_returns_an_error", func(t *testing.T) { + ctx := testutils.Context(t) underlayingORM := mocks.NewORM(t) - underlayingORM.On("GetUnconfirmedRows", limit).Return(nil, fmt.Errorf("some_error")).Once() + underlayingORM.On("GetUnconfirmedRows", mock.Anything, limit).Return(nil, fmt.Errorf("some_error")).Once() orm := s4.NewCachedORMWrapper(underlayingORM, lggr) - actualRow, err := orm.GetUnconfirmedRows(limit) + actualRow, err := orm.GetUnconfirmedRows(ctx, limit) require.Nil(t, actualRow) require.EqualError(t, err, "some_error") }) diff --git a/core/services/s4/in_memory_orm.go b/core/services/s4/in_memory_orm.go index 28b50ce430c..723f8820999 100644 --- a/core/services/s4/in_memory_orm.go +++ b/core/services/s4/in_memory_orm.go @@ -1,12 +1,12 @@ package s4 import ( + "context" "sort" "sync" "time" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) type key struct { @@ -32,7 +32,7 @@ func NewInMemoryORM() ORM { } } -func (o *inMemoryOrm) Get(address *big.Big, slotId uint, qopts ...pg.QOpt) (*Row, error) { +func (o *inMemoryOrm) Get(ctx context.Context, address *big.Big, slotId uint) (*Row, error) { o.mu.RLock() defer o.mu.RUnlock() @@ -47,7 +47,7 @@ func (o *inMemoryOrm) Get(address *big.Big, slotId uint, qopts ...pg.QOpt) (*Row return mrow.Row.Clone(), nil } -func (o *inMemoryOrm) Update(row *Row, qopts ...pg.QOpt) error { +func (o *inMemoryOrm) Update(ctx context.Context, row *Row) error { o.mu.Lock() defer o.mu.Unlock() @@ -74,7 +74,7 @@ func (o *inMemoryOrm) Update(row *Row, qopts ...pg.QOpt) error { return nil } -func (o *inMemoryOrm) DeleteExpired(limit uint, now time.Time, qopts ...pg.QOpt) (int64, error) { +func (o *inMemoryOrm) DeleteExpired(ctx context.Context, limit uint, now time.Time) (int64, error) { o.mu.Lock() defer o.mu.Unlock() @@ -94,7 +94,7 @@ func (o *inMemoryOrm) DeleteExpired(limit uint, now time.Time, qopts ...pg.QOpt) return int64(len(queue)), nil } -func (o *inMemoryOrm) GetSnapshot(addressRange *AddressRange, qopts ...pg.QOpt) ([]*SnapshotRow, error) { +func (o *inMemoryOrm) GetSnapshot(ctx context.Context, _ *AddressRange) ([]*SnapshotRow, error) { o.mu.RLock() defer o.mu.RUnlock() @@ -115,7 +115,7 @@ func (o *inMemoryOrm) GetSnapshot(addressRange *AddressRange, qopts ...pg.QOpt) return rows, nil } -func (o *inMemoryOrm) GetUnconfirmedRows(limit uint, qopts ...pg.QOpt) ([]*Row, error) { +func (o *inMemoryOrm) GetUnconfirmedRows(ctx context.Context, limit uint) ([]*Row, error) { o.mu.RLock() defer o.mu.RUnlock() diff --git a/core/services/s4/in_memory_orm_test.go b/core/services/s4/in_memory_orm_test.go index 318db5f1a44..db4f73ba1ef 100644 --- a/core/services/s4/in_memory_orm_test.go +++ b/core/services/s4/in_memory_orm_test.go @@ -33,33 +33,36 @@ func TestInMemoryORM(t *testing.T) { orm := s4.NewInMemoryORM() t.Run("row not found", func(t *testing.T) { - _, err := orm.Get(big.New(address.Big()), slotId) + ctx := testutils.Context(t) + _, err := orm.Get(ctx, big.New(address.Big()), slotId) assert.ErrorIs(t, err, s4.ErrNotFound) }) t.Run("insert and get", func(t *testing.T) { - err := orm.Update(row) + ctx := testutils.Context(t) + err := orm.Update(ctx, row) assert.NoError(t, err) - e, err := orm.Get(big.New(address.Big()), slotId) + e, err := orm.Get(ctx, big.New(address.Big()), slotId) assert.NoError(t, err) assert.Equal(t, row, e) }) t.Run("update and get", func(t *testing.T) { + ctx := testutils.Context(t) row.Version = 5 - err := orm.Update(row) + err := orm.Update(ctx, row) assert.NoError(t, err) // unconfirmed row requires greater version - err = orm.Update(row) + err = orm.Update(ctx, row) assert.ErrorIs(t, err, s4.ErrVersionTooLow) row.Confirmed = true - err = orm.Update(row) + err = orm.Update(ctx, row) assert.NoError(t, err) - e, err := orm.Get(big.New(address.Big()), slotId) + e, err := orm.Get(ctx, big.New(address.Big()), slotId) assert.NoError(t, err) assert.Equal(t, row, e) }) @@ -67,6 +70,7 @@ func TestInMemoryORM(t *testing.T) { func TestInMemoryORM_DeleteExpired(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := s4.NewInMemoryORM() baseTime := time.Now().Add(time.Minute).UTC() @@ -84,22 +88,23 @@ func TestInMemoryORM_DeleteExpired(t *testing.T) { Confirmed: false, Signature: []byte{}, } - err := orm.Update(row) + err := orm.Update(ctx, row) assert.NoError(t, err) } deadline := baseTime.Add(100 * time.Second) - count, err := orm.DeleteExpired(200, deadline) + count, err := orm.DeleteExpired(ctx, 200, deadline) assert.NoError(t, err) assert.Equal(t, int64(100), count) - rows, err := orm.GetUnconfirmedRows(200) + rows, err := orm.GetUnconfirmedRows(ctx, 200) assert.NoError(t, err) assert.Len(t, rows, 156) } func TestInMemoryORM_GetUnconfirmedRows(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := s4.NewInMemoryORM() expiration := time.Now().Add(100 * time.Second).UnixMilli() @@ -117,18 +122,19 @@ func TestInMemoryORM_GetUnconfirmedRows(t *testing.T) { Confirmed: i >= 100, Signature: []byte{}, } - err := orm.Update(row) + err := orm.Update(ctx, row) assert.NoError(t, err) time.Sleep(time.Millisecond) } - rows, err := orm.GetUnconfirmedRows(100) + rows, err := orm.GetUnconfirmedRows(ctx, 100) assert.NoError(t, err) assert.Len(t, rows, 100) } func TestInMemoryORM_GetSnapshot(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := s4.NewInMemoryORM() expiration := time.Now().Add(100 * time.Second).UnixMilli() @@ -147,11 +153,11 @@ func TestInMemoryORM_GetSnapshot(t *testing.T) { Confirmed: i >= 100, Signature: []byte{}, } - err := orm.Update(row) + err := orm.Update(ctx, row) assert.NoError(t, err) } - rows, err := orm.GetSnapshot(s4.NewFullAddressRange()) + rows, err := orm.GetSnapshot(ctx, s4.NewFullAddressRange()) assert.NoError(t, err) assert.Len(t, rows, n) diff --git a/core/services/s4/mocks/orm.go b/core/services/s4/mocks/orm.go index 3b8cac8e76d..4a5d7fa992d 100644 --- a/core/services/s4/mocks/orm.go +++ b/core/services/s4/mocks/orm.go @@ -3,10 +3,11 @@ package mocks import ( + context "context" + big "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big" - mock "github.com/stretchr/testify/mock" - pg "github.com/smartcontractkit/chainlink/v2/core/services/pg" + mock "github.com/stretchr/testify/mock" s4 "github.com/smartcontractkit/chainlink/v2/core/services/s4" @@ -18,16 +19,9 @@ type ORM struct { mock.Mock } -// DeleteExpired provides a mock function with given fields: limit, utcNow, qopts -func (_m *ORM) DeleteExpired(limit uint, utcNow time.Time, qopts ...pg.QOpt) (int64, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, limit, utcNow) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// DeleteExpired provides a mock function with given fields: ctx, limit, utcNow +func (_m *ORM) DeleteExpired(ctx context.Context, limit uint, utcNow time.Time) (int64, error) { + ret := _m.Called(ctx, limit, utcNow) if len(ret) == 0 { panic("no return value specified for DeleteExpired") @@ -35,17 +29,17 @@ func (_m *ORM) DeleteExpired(limit uint, utcNow time.Time, qopts ...pg.QOpt) (in var r0 int64 var r1 error - if rf, ok := ret.Get(0).(func(uint, time.Time, ...pg.QOpt) (int64, error)); ok { - return rf(limit, utcNow, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, uint, time.Time) (int64, error)); ok { + return rf(ctx, limit, utcNow) } - if rf, ok := ret.Get(0).(func(uint, time.Time, ...pg.QOpt) int64); ok { - r0 = rf(limit, utcNow, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, uint, time.Time) int64); ok { + r0 = rf(ctx, limit, utcNow) } else { r0 = ret.Get(0).(int64) } - if rf, ok := ret.Get(1).(func(uint, time.Time, ...pg.QOpt) error); ok { - r1 = rf(limit, utcNow, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, uint, time.Time) error); ok { + r1 = rf(ctx, limit, utcNow) } else { r1 = ret.Error(1) } @@ -53,16 +47,9 @@ func (_m *ORM) DeleteExpired(limit uint, utcNow time.Time, qopts ...pg.QOpt) (in return r0, r1 } -// Get provides a mock function with given fields: address, slotId, qopts -func (_m *ORM) Get(address *big.Big, slotId uint, qopts ...pg.QOpt) (*s4.Row, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, address, slotId) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// Get provides a mock function with given fields: ctx, address, slotId +func (_m *ORM) Get(ctx context.Context, address *big.Big, slotId uint) (*s4.Row, error) { + ret := _m.Called(ctx, address, slotId) if len(ret) == 0 { panic("no return value specified for Get") @@ -70,19 +57,19 @@ func (_m *ORM) Get(address *big.Big, slotId uint, qopts ...pg.QOpt) (*s4.Row, er var r0 *s4.Row var r1 error - if rf, ok := ret.Get(0).(func(*big.Big, uint, ...pg.QOpt) (*s4.Row, error)); ok { - return rf(address, slotId, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, *big.Big, uint) (*s4.Row, error)); ok { + return rf(ctx, address, slotId) } - if rf, ok := ret.Get(0).(func(*big.Big, uint, ...pg.QOpt) *s4.Row); ok { - r0 = rf(address, slotId, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, *big.Big, uint) *s4.Row); ok { + r0 = rf(ctx, address, slotId) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*s4.Row) } } - if rf, ok := ret.Get(1).(func(*big.Big, uint, ...pg.QOpt) error); ok { - r1 = rf(address, slotId, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, *big.Big, uint) error); ok { + r1 = rf(ctx, address, slotId) } else { r1 = ret.Error(1) } @@ -90,16 +77,9 @@ func (_m *ORM) Get(address *big.Big, slotId uint, qopts ...pg.QOpt) (*s4.Row, er return r0, r1 } -// GetSnapshot provides a mock function with given fields: addressRange, qopts -func (_m *ORM) GetSnapshot(addressRange *s4.AddressRange, qopts ...pg.QOpt) ([]*s4.SnapshotRow, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, addressRange) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// GetSnapshot provides a mock function with given fields: ctx, addressRange +func (_m *ORM) GetSnapshot(ctx context.Context, addressRange *s4.AddressRange) ([]*s4.SnapshotRow, error) { + ret := _m.Called(ctx, addressRange) if len(ret) == 0 { panic("no return value specified for GetSnapshot") @@ -107,19 +87,19 @@ func (_m *ORM) GetSnapshot(addressRange *s4.AddressRange, qopts ...pg.QOpt) ([]* var r0 []*s4.SnapshotRow var r1 error - if rf, ok := ret.Get(0).(func(*s4.AddressRange, ...pg.QOpt) ([]*s4.SnapshotRow, error)); ok { - return rf(addressRange, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, *s4.AddressRange) ([]*s4.SnapshotRow, error)); ok { + return rf(ctx, addressRange) } - if rf, ok := ret.Get(0).(func(*s4.AddressRange, ...pg.QOpt) []*s4.SnapshotRow); ok { - r0 = rf(addressRange, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, *s4.AddressRange) []*s4.SnapshotRow); ok { + r0 = rf(ctx, addressRange) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*s4.SnapshotRow) } } - if rf, ok := ret.Get(1).(func(*s4.AddressRange, ...pg.QOpt) error); ok { - r1 = rf(addressRange, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, *s4.AddressRange) error); ok { + r1 = rf(ctx, addressRange) } else { r1 = ret.Error(1) } @@ -127,16 +107,9 @@ func (_m *ORM) GetSnapshot(addressRange *s4.AddressRange, qopts ...pg.QOpt) ([]* return r0, r1 } -// GetUnconfirmedRows provides a mock function with given fields: limit, qopts -func (_m *ORM) GetUnconfirmedRows(limit uint, qopts ...pg.QOpt) ([]*s4.Row, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, limit) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// GetUnconfirmedRows provides a mock function with given fields: ctx, limit +func (_m *ORM) GetUnconfirmedRows(ctx context.Context, limit uint) ([]*s4.Row, error) { + ret := _m.Called(ctx, limit) if len(ret) == 0 { panic("no return value specified for GetUnconfirmedRows") @@ -144,19 +117,19 @@ func (_m *ORM) GetUnconfirmedRows(limit uint, qopts ...pg.QOpt) ([]*s4.Row, erro var r0 []*s4.Row var r1 error - if rf, ok := ret.Get(0).(func(uint, ...pg.QOpt) ([]*s4.Row, error)); ok { - return rf(limit, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, uint) ([]*s4.Row, error)); ok { + return rf(ctx, limit) } - if rf, ok := ret.Get(0).(func(uint, ...pg.QOpt) []*s4.Row); ok { - r0 = rf(limit, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, uint) []*s4.Row); ok { + r0 = rf(ctx, limit) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]*s4.Row) } } - if rf, ok := ret.Get(1).(func(uint, ...pg.QOpt) error); ok { - r1 = rf(limit, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, uint) error); ok { + r1 = rf(ctx, limit) } else { r1 = ret.Error(1) } @@ -164,24 +137,17 @@ func (_m *ORM) GetUnconfirmedRows(limit uint, qopts ...pg.QOpt) ([]*s4.Row, erro return r0, r1 } -// Update provides a mock function with given fields: row, qopts -func (_m *ORM) Update(row *s4.Row, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, row) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// Update provides a mock function with given fields: ctx, row +func (_m *ORM) Update(ctx context.Context, row *s4.Row) error { + ret := _m.Called(ctx, row) if len(ret) == 0 { panic("no return value specified for Update") } var r0 error - if rf, ok := ret.Get(0).(func(*s4.Row, ...pg.QOpt) error); ok { - r0 = rf(row, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, *s4.Row) error); ok { + r0 = rf(ctx, row) } else { r0 = ret.Error(0) } diff --git a/core/services/s4/orm.go b/core/services/s4/orm.go index 4d3cee9312a..952d8a33b24 100644 --- a/core/services/s4/orm.go +++ b/core/services/s4/orm.go @@ -1,10 +1,10 @@ package s4 import ( + "context" "time" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) // Row represents a data row persisted by ORM. @@ -36,26 +36,26 @@ type ORM interface { // Get reads a row for the given address and slotId combination. // If such row does not exist, ErrNotFound is returned. // There is no filter on Expiration. - Get(address *big.Big, slotId uint, qopts ...pg.QOpt) (*Row, error) + Get(ctx context.Context, address *big.Big, slotId uint) (*Row, error) // Update inserts or updates the row identified by (Address, SlotId) pair. // When updating, the new row must have greater or equal version, // otherwise ErrVersionTooLow is returned. // UpdatedAt field value is ignored. - Update(row *Row, qopts ...pg.QOpt) error + Update(ctx context.Context, row *Row) error // DeleteExpired deletes any entries having Expiration < utcNow, // up to the given limit. // Returns the number of deleted rows. - DeleteExpired(limit uint, utcNow time.Time, qopts ...pg.QOpt) (int64, error) + DeleteExpired(ctx context.Context, limit uint, utcNow time.Time) (int64, error) // GetSnapshot selects all non-expired row versions for the given addresses range. // For the full address range, use NewFullAddressRange(). - GetSnapshot(addressRange *AddressRange, qopts ...pg.QOpt) ([]*SnapshotRow, error) + GetSnapshot(ctx context.Context, addressRange *AddressRange) ([]*SnapshotRow, error) // GetUnconfirmedRows selects all non-expired, non-confirmed rows ordered by UpdatedAt. // The number of returned rows is limited to the given limit. - GetUnconfirmedRows(limit uint, qopts ...pg.QOpt) ([]*Row, error) + GetUnconfirmedRows(ctx context.Context, limit uint) ([]*Row, error) } func (r Row) Clone() *Row { diff --git a/core/services/s4/postgres_orm.go b/core/services/s4/postgres_orm.go index 1f92f2e1281..3d271e543d7 100644 --- a/core/services/s4/postgres_orm.go +++ b/core/services/s4/postgres_orm.go @@ -1,16 +1,15 @@ package s4 import ( + "context" "database/sql" "fmt" "time" - "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big" - "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" - - "github.com/jmoiron/sqlx" "github.com/pkg/errors" + + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" + "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big" ) const ( @@ -19,28 +18,27 @@ const ( ) type orm struct { - q pg.Q + ds sqlutil.DataSource tableName string namespace string } var _ ORM = (*orm)(nil) -func NewPostgresORM(db *sqlx.DB, lggr logger.Logger, cfg pg.QConfig, tableName, namespace string) ORM { +func NewPostgresORM(ds sqlutil.DataSource, tableName, namespace string) ORM { return &orm{ - q: pg.NewQ(db, lggr, cfg), + ds: ds, tableName: fmt.Sprintf(`"%s".%s`, s4PostgresSchema, tableName), namespace: namespace, } } -func (o orm) Get(address *big.Big, slotId uint, qopts ...pg.QOpt) (*Row, error) { +func (o *orm) Get(ctx context.Context, address *big.Big, slotId uint) (*Row, error) { row := &Row{} - q := o.q.WithOpts(qopts...) stmt := fmt.Sprintf(`SELECT address, slot_id, version, expiration, confirmed, payload, signature FROM %s WHERE namespace=$1 AND address=$2 AND slot_id=$3;`, o.tableName) - if err := q.Get(row, stmt, o.namespace, address, slotId); err != nil { + if err := o.ds.GetContext(ctx, row, stmt, o.namespace, address, slotId); err != nil { if errors.Is(err, sql.ErrNoRows) { err = ErrNotFound } @@ -49,9 +47,7 @@ WHERE namespace=$1 AND address=$2 AND slot_id=$3;`, o.tableName) return row, nil } -func (o orm) Update(row *Row, qopts ...pg.QOpt) error { - q := o.q.WithOpts(qopts...) - +func (o *orm) Update(ctx context.Context, row *Row) error { // This query inserts or updates a row, depending on whether the version is higher than the existing one. // We only allow the same version when the row is confirmed. // We never transition back from unconfirmed to confirmed state. @@ -67,31 +63,28 @@ updated_at = NOW() WHERE (t.version < EXCLUDED.version) OR (t.version <= EXCLUDED.version AND EXCLUDED.confirmed IS TRUE) RETURNING id;`, o.tableName) var id uint64 - err := q.Get(&id, stmt, o.namespace, row.Address, row.SlotId, row.Version, row.Expiration, row.Confirmed, row.Payload, row.Signature) + err := o.ds.GetContext(ctx, &id, stmt, o.namespace, row.Address, row.SlotId, row.Version, row.Expiration, row.Confirmed, row.Payload, row.Signature) if errors.Is(err, sql.ErrNoRows) { return ErrVersionTooLow } return err } -func (o orm) DeleteExpired(limit uint, utcNow time.Time, qopts ...pg.QOpt) (int64, error) { - q := o.q.WithOpts(qopts...) - +func (o *orm) DeleteExpired(ctx context.Context, limit uint, utcNow time.Time) (int64, error) { with := fmt.Sprintf(`WITH rows AS (SELECT id FROM %s WHERE namespace = $1 AND expiration < $2 LIMIT $3)`, o.tableName) stmt := fmt.Sprintf(`%s DELETE FROM %s WHERE id IN (SELECT id FROM rows);`, with, o.tableName) - result, err := q.Exec(stmt, o.namespace, utcNow.UnixMilli(), limit) + result, err := o.ds.ExecContext(ctx, stmt, o.namespace, utcNow.UnixMilli(), limit) if err != nil { return 0, err } return result.RowsAffected() } -func (o orm) GetSnapshot(addressRange *AddressRange, qopts ...pg.QOpt) ([]*SnapshotRow, error) { - q := o.q.WithOpts(qopts...) +func (o *orm) GetSnapshot(ctx context.Context, addressRange *AddressRange) ([]*SnapshotRow, error) { rows := make([]*SnapshotRow, 0) stmt := fmt.Sprintf(`SELECT address, slot_id, version, expiration, confirmed, octet_length(payload) AS payload_size FROM %s WHERE namespace = $1 AND address >= $2 AND address <= $3;`, o.tableName) - if err := q.Select(&rows, stmt, o.namespace, addressRange.MinAddress, addressRange.MaxAddress); err != nil { + if err := o.ds.SelectContext(ctx, &rows, stmt, o.namespace, addressRange.MinAddress, addressRange.MaxAddress); err != nil { if !errors.Is(err, sql.ErrNoRows) { return nil, err } @@ -99,13 +92,12 @@ func (o orm) GetSnapshot(addressRange *AddressRange, qopts ...pg.QOpt) ([]*Snaps return rows, nil } -func (o orm) GetUnconfirmedRows(limit uint, qopts ...pg.QOpt) ([]*Row, error) { - q := o.q.WithOpts(qopts...) +func (o *orm) GetUnconfirmedRows(ctx context.Context, limit uint) ([]*Row, error) { rows := make([]*Row, 0) stmt := fmt.Sprintf(`SELECT address, slot_id, version, expiration, confirmed, payload, signature FROM %s WHERE namespace = $1 AND confirmed IS FALSE ORDER BY updated_at LIMIT $2;`, o.tableName) - if err := q.Select(&rows, stmt, o.namespace, limit); err != nil { + if err := o.ds.SelectContext(ctx, &rows, stmt, o.namespace, limit); err != nil { if !errors.Is(err, sql.ErrNoRows) { return nil, err } diff --git a/core/services/s4/postgres_orm_test.go b/core/services/s4/postgres_orm_test.go index d26f082ce5b..660002a2e3b 100644 --- a/core/services/s4/postgres_orm_test.go +++ b/core/services/s4/postgres_orm_test.go @@ -10,7 +10,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/internal/cltest" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/pgtest" - "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/s4" "github.com/stretchr/testify/assert" @@ -20,8 +19,7 @@ func setupORM(t *testing.T, namespace string) s4.ORM { t.Helper() db := pgtest.NewSqlxDB(t) - lggr := logger.TestLogger(t) - orm := s4.NewPostgresORM(db, lggr, pgtest.NewQConfig(true), s4.SharedTableName, namespace) + orm := s4.NewPostgresORM(db, s4.SharedTableName, namespace) t.Cleanup(func() { assert.NoError(t, db.Close()) @@ -59,64 +57,67 @@ func TestNewPostgresOrm(t *testing.T) { func TestPostgresORM_UpdateAndGet(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := setupORM(t, "test") rows := generateTestRows(t, 10) for _, row := range rows { - err := orm.Update(row) + err := orm.Update(ctx, row) assert.NoError(t, err) row.Version++ - err = orm.Update(row) + err = orm.Update(ctx, row) assert.NoError(t, err) - err = orm.Update(row) + err = orm.Update(ctx, row) if !row.Confirmed { assert.ErrorIs(t, err, s4.ErrVersionTooLow) } } for _, row := range rows { - gotRow, err := orm.Get(row.Address, row.SlotId) + gotRow, err := orm.Get(ctx, row.Address, row.SlotId) assert.NoError(t, err) assert.Equal(t, row, gotRow) } rows = generateTestRows(t, 1) - _, err := orm.Get(rows[0].Address, rows[0].SlotId) + _, err := orm.Get(ctx, rows[0].Address, rows[0].SlotId) assert.ErrorIs(t, err, s4.ErrNotFound) } func TestPostgresORM_UpdateSimpleFlow(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := setupORM(t, "test") row := generateTestRows(t, 1)[0] // user sends a new version - assert.NoError(t, orm.Update(row)) + assert.NoError(t, orm.Update(ctx, row)) // OCR round confirms it row.Confirmed = true - assert.NoError(t, orm.Update(row)) + assert.NoError(t, orm.Update(ctx, row)) // user sends a higher version (unconfirmed) row.Version++ row.Confirmed = false - assert.NoError(t, orm.Update(row)) + assert.NoError(t, orm.Update(ctx, row)) // and again, before OCR has a chance to confirm row.Version++ - assert.NoError(t, orm.Update(row)) + assert.NoError(t, orm.Update(ctx, row)) // user tries to send a lower version row.Version-- - assert.Error(t, orm.Update(row)) + assert.Error(t, orm.Update(ctx, row)) } func TestPostgresORM_DeleteExpired(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := setupORM(t, "test") @@ -125,17 +126,17 @@ func TestPostgresORM_DeleteExpired(t *testing.T) { rows := generateTestRows(t, total) for _, row := range rows { - err := orm.Update(row) + err := orm.Update(ctx, row) assert.NoError(t, err) } - deleted, err := orm.DeleteExpired(expired, time.Now().Add(2*time.Hour).UTC()) + deleted, err := orm.DeleteExpired(ctx, expired, time.Now().Add(2*time.Hour).UTC()) assert.NoError(t, err) assert.Equal(t, int64(expired), deleted) count := 0 for _, row := range rows { - _, err := orm.Get(row.Address, row.SlotId) + _, err := orm.Get(ctx, row.Address, row.SlotId) if !errors.Is(err, s4.ErrNotFound) { count++ } @@ -149,21 +150,23 @@ func TestPostgresORM_GetSnapshot(t *testing.T) { orm := setupORM(t, "test") t.Run("no rows", func(t *testing.T) { - rows, err := orm.GetSnapshot(s4.NewFullAddressRange()) + ctx := testutils.Context(t) + rows, err := orm.GetSnapshot(ctx, s4.NewFullAddressRange()) assert.NoError(t, err) assert.Empty(t, rows) }) t.Run("with rows", func(t *testing.T) { + ctx := testutils.Context(t) rows := generateTestRows(t, 100) for _, row := range rows { - err := orm.Update(row) + err := orm.Update(ctx, row) assert.NoError(t, err) } t.Run("full range", func(t *testing.T) { - snapshot, err := orm.GetSnapshot(s4.NewFullAddressRange()) + snapshot, err := orm.GetSnapshot(testutils.Context(t), s4.NewFullAddressRange()) assert.NoError(t, err) assert.Equal(t, len(rows), len(snapshot)) @@ -188,7 +191,7 @@ func TestPostgresORM_GetSnapshot(t *testing.T) { t.Run("half range", func(t *testing.T) { ar, err := s4.NewInitialAddressRangeForIntervals(2) assert.NoError(t, err) - snapshot, err := orm.GetSnapshot(ar) + snapshot, err := orm.GetSnapshot(testutils.Context(t), ar) assert.NoError(t, err) for _, sr := range snapshot { assert.True(t, ar.Contains(sr.Address)) @@ -203,21 +206,23 @@ func TestPostgresORM_GetUnconfirmedRows(t *testing.T) { orm := setupORM(t, "test") t.Run("no rows", func(t *testing.T) { - rows, err := orm.GetUnconfirmedRows(5) + ctx := testutils.Context(t) + rows, err := orm.GetUnconfirmedRows(ctx, 5) assert.NoError(t, err) assert.Empty(t, rows) }) t.Run("with rows", func(t *testing.T) { + ctx := testutils.Context(t) rows := generateTestRows(t, 10) for _, row := range rows { - err := orm.Update(row) + err := orm.Update(ctx, row) assert.NoError(t, err) time.Sleep(testutils.TestInterval / 10) } - gotRows, err := orm.GetUnconfirmedRows(5) + gotRows, err := orm.GetUnconfirmedRows(ctx, 5) assert.NoError(t, err) assert.Len(t, gotRows, 5) @@ -229,6 +234,7 @@ func TestPostgresORM_GetUnconfirmedRows(t *testing.T) { func TestPostgresORM_Namespace(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) ormA := setupORM(t, "a") ormB := setupORM(t, "b") @@ -237,44 +243,45 @@ func TestPostgresORM_Namespace(t *testing.T) { rowsA := generateTestRows(t, n) rowsB := generateTestRows(t, n) for i := 0; i < n; i++ { - err := ormA.Update(rowsA[i]) + err := ormA.Update(ctx, rowsA[i]) assert.NoError(t, err) - err = ormB.Update(rowsB[i]) + err = ormB.Update(ctx, rowsB[i]) assert.NoError(t, err) } - urowsA, err := ormA.GetUnconfirmedRows(n) + urowsA, err := ormA.GetUnconfirmedRows(ctx, n) assert.NoError(t, err) assert.Len(t, urowsA, n/2) - urowsB, err := ormB.GetUnconfirmedRows(n) + urowsB, err := ormB.GetUnconfirmedRows(ctx, n) assert.NoError(t, err) assert.Len(t, urowsB, n/2) - _, err = ormB.DeleteExpired(n, time.Now().UTC()) + _, err = ormB.DeleteExpired(ctx, n, time.Now().UTC()) assert.NoError(t, err) - snapshotA, err := ormA.GetSnapshot(s4.NewFullAddressRange()) + snapshotA, err := ormA.GetSnapshot(ctx, s4.NewFullAddressRange()) assert.NoError(t, err) assert.Len(t, snapshotA, n) } func TestPostgresORM_BigIntVersion(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) orm := setupORM(t, "test") row := generateTestRows(t, 1)[0] row.Version = math.MaxUint64 - 10 - err := orm.Update(row) + err := orm.Update(ctx, row) assert.NoError(t, err) row.Version++ - err = orm.Update(row) + err = orm.Update(ctx, row) assert.NoError(t, err) - gotRow, err := orm.Get(row.Address, row.SlotId) + gotRow, err := orm.Get(ctx, row.Address, row.SlotId) assert.NoError(t, err) assert.Equal(t, row, gotRow) } diff --git a/core/services/s4/storage.go b/core/services/s4/storage.go index 02ba9c7bd50..1af14ec269f 100644 --- a/core/services/s4/storage.go +++ b/core/services/s4/storage.go @@ -5,11 +5,10 @@ import ( "github.com/jonboulle/clockwork" + "github.com/ethereum/go-ethereum/common" + "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big" "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" - - "github.com/ethereum/go-ethereum/common" ) // Constraints specifies the global storage constraints. @@ -95,7 +94,7 @@ func (s *storage) Get(ctx context.Context, key *Key) (*Record, *Metadata, error) } bigAddress := big.New(key.Address.Big()) - row, err := s.orm.Get(bigAddress, key.SlotId, pg.WithParentCtx(ctx)) + row, err := s.orm.Get(ctx, bigAddress, key.SlotId) if err != nil { return nil, nil, err } @@ -125,7 +124,7 @@ func (s *storage) List(ctx context.Context, address common.Address) ([]*Snapshot if err != nil { return nil, err } - return s.orm.GetSnapshot(sar, pg.WithParentCtx(ctx)) + return s.orm.GetSnapshot(ctx, sar) } func (s *storage) Put(ctx context.Context, key *Key, record *Record, signature []byte) error { @@ -161,5 +160,5 @@ func (s *storage) Put(ctx context.Context, key *Key, record *Record, signature [ copy(row.Payload, record.Payload) copy(row.Signature, signature) - return s.orm.Update(row, pg.WithParentCtx(ctx)) + return s.orm.Update(ctx, row) } diff --git a/core/services/s4/storage_test.go b/core/services/s4/storage_test.go index b643609f449..8deb23bb979 100644 --- a/core/services/s4/storage_test.go +++ b/core/services/s4/storage_test.go @@ -53,7 +53,7 @@ func TestStorage_Errors(t *testing.T) { SlotId: 1, Version: 0, } - ormMock.On("Get", big.New(key.Address.Big()), key.SlotId, mock.Anything).Return(nil, s4.ErrNotFound) + ormMock.On("Get", mock.Anything, big.New(key.Address.Big()), key.SlotId).Return(nil, s4.ErrNotFound) _, _, err := storage.Get(testutils.Context(t), key) assert.ErrorIs(t, err, s4.ErrNotFound) }) @@ -181,7 +181,7 @@ func TestStorage_PutAndGet(t *testing.T) { assert.NoError(t, err) ormMock.On("Update", mock.Anything, mock.Anything).Return(nil) - ormMock.On("Get", big.New(key.Address.Big()), uint(2), mock.Anything).Return(&s4.Row{ + ormMock.On("Get", mock.Anything, big.New(key.Address.Big()), uint(2)).Return(&s4.Row{ Address: big.New(key.Address.Big()), SlotId: key.SlotId, Version: key.Version, @@ -221,7 +221,7 @@ func TestStorage_List(t *testing.T) { addressRange, err := s4.NewSingleAddressRange(big.New(address.Big())) assert.NoError(t, err) - ormMock.On("GetSnapshot", addressRange, mock.Anything).Return(ormRows, nil) + ormMock.On("GetSnapshot", mock.Anything, addressRange).Return(ormRows, nil) rows, err := storage.List(testutils.Context(t), address) require.NoError(t, err) From 6dd1ef792d8e14c0976b935c8a24dc1d4ac5fd9d Mon Sep 17 00:00:00 2001 From: Erik Burton Date: Tue, 16 Apr 2024 17:27:03 -0700 Subject: [PATCH 02/19] chore: update smartcontractkit/chainlink-github-actions/github-app-token-issuer to v2.3.12 (#12853) --- .github/workflows/helm-chart-publish.yml | 2 +- .github/workflows/operator-ui-ci.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/helm-chart-publish.yml b/.github/workflows/helm-chart-publish.yml index 27c2b3c310a..927ed4b0166 100644 --- a/.github/workflows/helm-chart-publish.yml +++ b/.github/workflows/helm-chart-publish.yml @@ -25,7 +25,7 @@ jobs: - name: Get Github Token id: get-gh-token - uses: smartcontractkit/chainlink-github-actions/github-app-token-issuer@5dd916d08c03cb5f9a97304f4f174820421bb946 # v2.3.11 + uses: smartcontractkit/chainlink-github-actions/github-app-token-issuer@5874ff7211cf5a5a2670bb010fbff914eaaae138 # v2.3.12 with: url: ${{ secrets.GATI_LAMBDA_FUNCTION_URL }} diff --git a/.github/workflows/operator-ui-ci.yml b/.github/workflows/operator-ui-ci.yml index 22491c20029..b67eb2c35f7 100644 --- a/.github/workflows/operator-ui-ci.yml +++ b/.github/workflows/operator-ui-ci.yml @@ -38,7 +38,7 @@ jobs: - name: Get Github Token id: get-gh-token - uses: smartcontractkit/chainlink-github-actions/github-app-token-issuer@5dd916d08c03cb5f9a97304f4f174820421bb946 # v2.3.11 + uses: smartcontractkit/chainlink-github-actions/github-app-token-issuer@5874ff7211cf5a5a2670bb010fbff914eaaae138 # v2.3.12 with: url: ${{ secrets.AWS_INFRA_RELENG_TOKEN_ISSUER_LAMBDA_URL }} From af2a30d00e6d519c545dad806db41422266beb81 Mon Sep 17 00:00:00 2001 From: Erik Burton Date: Tue, 16 Apr 2024 18:04:43 -0700 Subject: [PATCH 03/19] fix: trim role-session-name to less than 64 characters (#12855) --- .github/workflows/goreleaser-build-publish-develop.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/goreleaser-build-publish-develop.yml b/.github/workflows/goreleaser-build-publish-develop.yml index 7c8cf07e8a1..b7fe09f35b2 100644 --- a/.github/workflows/goreleaser-build-publish-develop.yml +++ b/.github/workflows/goreleaser-build-publish-develop.yml @@ -26,7 +26,7 @@ jobs: role-duration-seconds: ${{ secrets.AWS_ROLE_DURATION_SECONDS }} aws-region: ${{ secrets.AWS_REGION }} mask-aws-account-id: true - role-session-name: goreleaser-build-publish-chainlink.push-chainlink-develop-goreleaser + role-session-name: goreleaser-build-publish-chainlink.push-develop - name: Build, sign, and publish image id: build-sign-publish uses: ./.github/actions/goreleaser-build-sign-publish From 707c7076da07e908919ba8fc8277c2737020015b Mon Sep 17 00:00:00 2001 From: frank zhu Date: Wed, 17 Apr 2024 00:11:56 -0700 Subject: [PATCH 04/19] update changeset tags comment formatting (#12840) * update changeset tags comment formatting * fix --- .github/workflows/changeset.yml | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/.github/workflows/changeset.yml b/.github/workflows/changeset.yml index 9ad49fe343e..0af608624b1 100644 --- a/.github/workflows/changeset.yml +++ b/.github/workflows/changeset.yml @@ -10,17 +10,17 @@ jobs: changeset: env: TAGS: | - #nops : For any feature that is NOP facing and needs to be in the official Release Notes for the release. - #added : For any new functionality added. - #changed : For any change to the existing functionality. - #removed : For any functionality/config that is removed. - #updated : For any functionality that is updated. - #deprecation_notice : For any upcoming deprecation functionality. - #breaking_change : For any functionality that requires manual action for the node to boot. - #db_update : For any feature that introduces updates to database schema. - #wip : For any change that is not ready yet and external communication about it should be held off till it is feature complete. - #bugfix - For bug fixes. - #internal - For changesets that need to be excluded from the final changelog. + - `#added` For any new functionality added. + - `#breaking_change` For any functionality that requires manual action for the node to boot. + - `#bugfix` For bug fixes. + - `#changed` For any change to the existing functionality. + - `#db_update` For any feature that introduces updates to database schema. + - `#deprecation_notice` For any upcoming deprecation functionality. + - `#internal` For changesets that need to be excluded from the final changelog. + - `#nops` For any feature that is NOP facing and needs to be in the official Release Notes for the release. + - `#removed` For any functionality/config that is removed. + - `#updated` For any functionality that is updated. + - `#wip` For any change that is not ready yet and external communication about it should be held off till it is feature complete. # For security reasons, GITHUB_TOKEN is read-only on forks, so we cannot leave comments on PRs. # This check skips the job if it is detected we are running on a fork. @@ -66,9 +66,7 @@ jobs: with: message: | I see you updated files related to `core`. Please run `pnpm changeset` in the root directory to add a changeset as well as in the text include at least one of the following tags: - ``` ${{ env.TAGS }} - ``` reactions: eyes comment_tag: changeset-core @@ -111,9 +109,7 @@ jobs: with: message: | I see you added a changeset file but it does not contain a tag. Please edit the text include at least one of the following tags: - ``` ${{ env.TAGS }} - ``` reactions: eyes comment_tag: changeset-core-tags From 6a0b4a9b099663e3aed202f48f363afc4d111293 Mon Sep 17 00:00:00 2001 From: Jordan Krage Date: Wed, 17 Apr 2024 06:27:16 -0500 Subject: [PATCH 05/19] core/services/relay/evm/mercury: switch to sqlutil.DataSource (#12818) --- .changeset/pretty-flies-fold.md | 5 + core/services/chainlink/application.go | 4 +- .../ocr2/plugins/mercury/plugin_test.go | 3 +- core/services/relay/evm/evm.go | 2 +- core/services/relay/evm/mercury/orm.go | 52 ++++------ core/services/relay/evm/mercury/orm_test.go | 95 +++++++++---------- .../relay/evm/mercury/persistence_manager.go | 26 ++--- .../evm/mercury/persistence_manager_test.go | 2 +- .../relay/evm/mercury/transmitter_test.go | 10 +- .../services/relay/evm/mercury/types/types.go | 4 +- .../relay/evm/mercury/v1/data_source_test.go | 3 +- .../relay/evm/mercury/v2/data_source_test.go | 3 +- .../relay/evm/mercury/v3/data_source_test.go | 3 +- 13 files changed, 102 insertions(+), 110 deletions(-) create mode 100644 .changeset/pretty-flies-fold.md diff --git a/.changeset/pretty-flies-fold.md b/.changeset/pretty-flies-fold.md new file mode 100644 index 00000000000..d67a3117e14 --- /dev/null +++ b/.changeset/pretty-flies-fold.md @@ -0,0 +1,5 @@ +--- +"chainlink": patch +--- + +cor/services/relay/evm/mercury: switch to sqlutil.DataStore #internal diff --git a/core/services/chainlink/application.go b/core/services/chainlink/application.go index 6c373846205..832bea523b5 100644 --- a/core/services/chainlink/application.go +++ b/core/services/chainlink/application.go @@ -310,10 +310,10 @@ func NewApplication(opts ApplicationOpts) (Application, error) { var ( pipelineORM = pipeline.NewORM(sqlxDB, globalLogger, cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) bridgeORM = bridges.NewORM(sqlxDB) - mercuryORM = mercury.NewORM(sqlxDB, globalLogger, cfg.Database()) + mercuryORM = mercury.NewORM(opts.DB) pipelineRunner = pipeline.NewRunner(pipelineORM, bridgeORM, cfg.JobPipeline(), cfg.WebServer(), legacyEVMChains, keyStore.Eth(), keyStore.VRF(), globalLogger, restrictedHTTPClient, unrestrictedHTTPClient) jobORM = job.NewORM(sqlxDB, pipelineORM, bridgeORM, keyStore, globalLogger, cfg.Database()) - txmORM = txmgr.NewTxStore(sqlxDB, globalLogger) + txmORM = txmgr.NewTxStore(opts.DB, globalLogger) streamRegistry = streams.NewRegistry(globalLogger, pipelineRunner) ) diff --git a/core/services/ocr2/plugins/mercury/plugin_test.go b/core/services/ocr2/plugins/mercury/plugin_test.go index 3934105a390..131f51af4b4 100644 --- a/core/services/ocr2/plugins/mercury/plugin_test.go +++ b/core/services/ocr2/plugins/mercury/plugin_test.go @@ -26,7 +26,6 @@ import ( libocr2 "github.com/smartcontractkit/libocr/offchainreporting2plus" libocr2types "github.com/smartcontractkit/libocr/offchainreporting2plus/types" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" "github.com/smartcontractkit/chainlink/v2/core/services/relay" "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm/mercury/types" @@ -279,7 +278,7 @@ var _ plugins.RegistrarConfig = (*testRegistrarConfig)(nil) type testDataSourceORM struct{} // LatestReport implements types.DataSourceORM. -func (*testDataSourceORM) LatestReport(ctx context.Context, feedID [32]byte, qopts ...pg.QOpt) (report []byte, err error) { +func (*testDataSourceORM) LatestReport(ctx context.Context, feedID [32]byte) (report []byte, err error) { return []byte{1, 2, 3}, nil } diff --git a/core/services/relay/evm/evm.go b/core/services/relay/evm/evm.go index c8fe1b868a7..1a09e681f8a 100644 --- a/core/services/relay/evm/evm.go +++ b/core/services/relay/evm/evm.go @@ -130,7 +130,7 @@ func NewRelayer(lggr logger.Logger, chain legacyevm.Chain, opts RelayerOpts) (*R } lggr = lggr.Named("Relayer") - mercuryORM := mercury.NewORM(opts.DB, lggr, opts.QConfig) + mercuryORM := mercury.NewORM(opts.DS) lloORM := llo.NewORM(opts.DS, chain.ID()) cdcFactory := llo.NewChannelDefinitionCacheFactory(lggr, lloORM, chain.LogPoller()) return &Relayer{ diff --git a/core/services/relay/evm/mercury/orm.go b/core/services/relay/evm/mercury/orm.go index 19f2aa8e16b..6426ef54a5d 100644 --- a/core/services/relay/evm/mercury/orm.go +++ b/core/services/relay/evm/mercury/orm.go @@ -8,24 +8,22 @@ import ( "sync" "github.com/ethereum/go-ethereum/common" - "github.com/jmoiron/sqlx" "github.com/lib/pq" pkgerrors "github.com/pkg/errors" ocrtypes "github.com/smartcontractkit/libocr/offchainreporting2plus/types" - "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm/mercury/utils" "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm/mercury/wsrpc/pb" ) type ORM interface { - InsertTransmitRequest(serverURL string, req *pb.TransmitRequest, jobID int32, reportCtx ocrtypes.ReportContext, qopts ...pg.QOpt) error - DeleteTransmitRequests(serverURL string, reqs []*pb.TransmitRequest, qopts ...pg.QOpt) error - GetTransmitRequests(serverURL string, jobID int32, qopts ...pg.QOpt) ([]*Transmission, error) - PruneTransmitRequests(serverURL string, jobID int32, maxSize int, qopts ...pg.QOpt) error - LatestReport(ctx context.Context, feedID [32]byte, qopts ...pg.QOpt) (report []byte, err error) + InsertTransmitRequest(ctx context.Context, serverURL string, req *pb.TransmitRequest, jobID int32, reportCtx ocrtypes.ReportContext) error + DeleteTransmitRequests(ctx context.Context, serverURL string, reqs []*pb.TransmitRequest) error + GetTransmitRequests(ctx context.Context, serverURL string, jobID int32) ([]*Transmission, error) + PruneTransmitRequests(ctx context.Context, serverURL string, jobID int32, maxSize int) error + LatestReport(ctx context.Context, feedID [32]byte) (report []byte, err error) } func FeedIDFromReport(report ocrtypes.Report) (feedID utils.FeedID, err error) { @@ -36,32 +34,27 @@ func FeedIDFromReport(report ocrtypes.Report) (feedID utils.FeedID, err error) { } type orm struct { - q pg.Q + ds sqlutil.DataSource } -func NewORM(db *sqlx.DB, lggr logger.Logger, cfg pg.QConfig) ORM { - namedLogger := lggr.Named("MercuryORM") - q := pg.NewQ(db, namedLogger, cfg) - return &orm{ - q: q, - } +func NewORM(ds sqlutil.DataSource) ORM { + return &orm{ds: ds} } // InsertTransmitRequest inserts one transmit request if the payload does not exist already. -func (o *orm) InsertTransmitRequest(serverURL string, req *pb.TransmitRequest, jobID int32, reportCtx ocrtypes.ReportContext, qopts ...pg.QOpt) error { +func (o *orm) InsertTransmitRequest(ctx context.Context, serverURL string, req *pb.TransmitRequest, jobID int32, reportCtx ocrtypes.ReportContext) error { feedID, err := FeedIDFromReport(req.Payload) if err != nil { return err } - q := o.q.WithOpts(qopts...) var wg sync.WaitGroup wg.Add(2) var err1, err2 error go func() { defer wg.Done() - err1 = q.ExecQ(` + _, err1 = o.ds.ExecContext(ctx, ` INSERT INTO mercury_transmit_requests (server_url, payload, payload_hash, config_digest, epoch, round, extra_hash, job_id, feed_id) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) ON CONFLICT (server_url, payload_hash) DO NOTHING @@ -70,7 +63,7 @@ func (o *orm) InsertTransmitRequest(serverURL string, req *pb.TransmitRequest, j go func() { defer wg.Done() - err2 = q.ExecQ(` + _, err2 = o.ds.ExecContext(ctx, ` INSERT INTO feed_latest_reports (feed_id, report, epoch, round, updated_at, job_id) VALUES ($1, $2, $3, $4, NOW(), $5) ON CONFLICT (feed_id) DO UPDATE @@ -83,7 +76,7 @@ func (o *orm) InsertTransmitRequest(serverURL string, req *pb.TransmitRequest, j } // DeleteTransmitRequest deletes the given transmit requests if they exist. -func (o *orm) DeleteTransmitRequests(serverURL string, reqs []*pb.TransmitRequest, qopts ...pg.QOpt) error { +func (o *orm) DeleteTransmitRequests(ctx context.Context, serverURL string, reqs []*pb.TransmitRequest) error { if len(reqs) == 0 { return nil } @@ -93,8 +86,7 @@ func (o *orm) DeleteTransmitRequests(serverURL string, reqs []*pb.TransmitReques hashes = append(hashes, hashPayload(req.Payload)) } - q := o.q.WithOpts(qopts...) - err := q.ExecQ(` + _, err := o.ds.ExecContext(ctx, ` DELETE FROM mercury_transmit_requests WHERE server_url = $1 AND payload_hash = ANY($2) `, serverURL, hashes) @@ -102,11 +94,10 @@ func (o *orm) DeleteTransmitRequests(serverURL string, reqs []*pb.TransmitReques } // GetTransmitRequests returns all transmit requests in chronologically descending order. -func (o *orm) GetTransmitRequests(serverURL string, jobID int32, qopts ...pg.QOpt) ([]*Transmission, error) { - q := o.q.WithOpts(qopts...) +func (o *orm) GetTransmitRequests(ctx context.Context, serverURL string, jobID int32) ([]*Transmission, error) { // The priority queue uses epoch and round to sort transmissions so order by // the same fields here for optimal insertion into the pq. - rows, err := q.QueryContext(q.ParentCtx, ` + rows, err := o.ds.QueryContext(ctx, ` SELECT payload, config_digest, epoch, round, extra_hash FROM mercury_transmit_requests WHERE job_id = $1 AND server_url = $2 @@ -146,10 +137,9 @@ func (o *orm) GetTransmitRequests(serverURL string, jobID int32, qopts ...pg.QOp // PruneTransmitRequests keeps at most maxSize rows for the given job ID, // deleting the oldest transactions. -func (o *orm) PruneTransmitRequests(serverURL string, jobID int32, maxSize int, qopts ...pg.QOpt) error { - q := o.q.WithOpts(qopts...) +func (o *orm) PruneTransmitRequests(ctx context.Context, serverURL string, jobID int32, maxSize int) error { // Prune the oldest requests by epoch and round. - return q.ExecQ(` + _, err := o.ds.ExecContext(ctx, ` DELETE FROM mercury_transmit_requests WHERE job_id = $1 AND server_url = $2 AND payload_hash NOT IN ( @@ -160,11 +150,11 @@ func (o *orm) PruneTransmitRequests(serverURL string, jobID int32, maxSize int, LIMIT $3 ) `, jobID, serverURL, maxSize) + return err } -func (o *orm) LatestReport(ctx context.Context, feedID [32]byte, qopts ...pg.QOpt) (report []byte, err error) { - q := o.q.WithOpts(qopts...) - err = q.GetContext(ctx, &report, `SELECT report FROM feed_latest_reports WHERE feed_id = $1`, feedID[:]) +func (o *orm) LatestReport(ctx context.Context, feedID [32]byte) (report []byte, err error) { + err = o.ds.GetContext(ctx, &report, `SELECT report FROM feed_latest_reports WHERE feed_id = $1`, feedID[:]) if errors.Is(err, sql.ErrNoRows) { return nil, nil } diff --git a/core/services/relay/evm/mercury/orm_test.go b/core/services/relay/evm/mercury/orm_test.go index 14be878eeef..2b2e15ffd53 100644 --- a/core/services/relay/evm/mercury/orm_test.go +++ b/core/services/relay/evm/mercury/orm_test.go @@ -10,7 +10,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/pgtest" - "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm/mercury/wsrpc/pb" ) @@ -21,13 +20,13 @@ var ( ) func TestORM(t *testing.T) { + ctx := testutils.Context(t) db := pgtest.NewSqlxDB(t) jobID := rand.Int32() // foreign key constraints disabled so value doesn't matter pgtest.MustExec(t, db, `SET CONSTRAINTS mercury_transmit_requests_job_id_fkey DEFERRED`) pgtest.MustExec(t, db, `SET CONSTRAINTS feed_latest_reports_job_id_fkey DEFERRED`) - lggr := logger.TestLogger(t) - orm := NewORM(db, lggr, pgtest.NewQConfig(true)) + orm := NewORM(db) feedID := sampleFeedID reports := sampleReports @@ -49,25 +48,25 @@ func TestORM(t *testing.T) { // Test insert and get requests. // s1 - err = orm.InsertTransmitRequest(sURL, &pb.TransmitRequest{Payload: reports[0]}, jobID, reportContexts[0]) + err = orm.InsertTransmitRequest(ctx, sURL, &pb.TransmitRequest{Payload: reports[0]}, jobID, reportContexts[0]) require.NoError(t, err) - err = orm.InsertTransmitRequest(sURL, &pb.TransmitRequest{Payload: reports[1]}, jobID, reportContexts[1]) + err = orm.InsertTransmitRequest(ctx, sURL, &pb.TransmitRequest{Payload: reports[1]}, jobID, reportContexts[1]) require.NoError(t, err) - err = orm.InsertTransmitRequest(sURL, &pb.TransmitRequest{Payload: reports[2]}, jobID, reportContexts[2]) + err = orm.InsertTransmitRequest(ctx, sURL, &pb.TransmitRequest{Payload: reports[2]}, jobID, reportContexts[2]) require.NoError(t, err) // s2 - err = orm.InsertTransmitRequest(sURL2, &pb.TransmitRequest{Payload: reports[3]}, jobID, reportContexts[0]) + err = orm.InsertTransmitRequest(ctx, sURL2, &pb.TransmitRequest{Payload: reports[3]}, jobID, reportContexts[0]) require.NoError(t, err) - transmissions, err := orm.GetTransmitRequests(sURL, jobID) + transmissions, err := orm.GetTransmitRequests(ctx, sURL, jobID) require.NoError(t, err) require.Equal(t, transmissions, []*Transmission{ {Req: &pb.TransmitRequest{Payload: reports[2]}, ReportCtx: reportContexts[2]}, {Req: &pb.TransmitRequest{Payload: reports[1]}, ReportCtx: reportContexts[1]}, {Req: &pb.TransmitRequest{Payload: reports[0]}, ReportCtx: reportContexts[0]}, }) - transmissions, err = orm.GetTransmitRequests(sURL2, jobID) + transmissions, err = orm.GetTransmitRequests(ctx, sURL2, jobID) require.NoError(t, err) require.Equal(t, transmissions, []*Transmission{ {Req: &pb.TransmitRequest{Payload: reports[3]}, ReportCtx: reportContexts[0]}, @@ -79,10 +78,10 @@ func TestORM(t *testing.T) { assert.Equal(t, reports[2], l) // Test requests can be deleted. - err = orm.DeleteTransmitRequests(sURL, []*pb.TransmitRequest{{Payload: reports[1]}}) + err = orm.DeleteTransmitRequests(ctx, sURL, []*pb.TransmitRequest{{Payload: reports[1]}}) require.NoError(t, err) - transmissions, err = orm.GetTransmitRequests(sURL, jobID) + transmissions, err = orm.GetTransmitRequests(ctx, sURL, jobID) require.NoError(t, err) require.Equal(t, transmissions, []*Transmission{ {Req: &pb.TransmitRequest{Payload: reports[2]}, ReportCtx: reportContexts[2]}, @@ -94,10 +93,10 @@ func TestORM(t *testing.T) { assert.Equal(t, reports[2], l) // Test deleting non-existent requests does not error. - err = orm.DeleteTransmitRequests(sURL, []*pb.TransmitRequest{{Payload: []byte("does-not-exist")}}) + err = orm.DeleteTransmitRequests(ctx, sURL, []*pb.TransmitRequest{{Payload: []byte("does-not-exist")}}) require.NoError(t, err) - transmissions, err = orm.GetTransmitRequests(sURL, jobID) + transmissions, err = orm.GetTransmitRequests(ctx, sURL, jobID) require.NoError(t, err) require.Equal(t, transmissions, []*Transmission{ {Req: &pb.TransmitRequest{Payload: reports[2]}, ReportCtx: reportContexts[2]}, @@ -105,7 +104,7 @@ func TestORM(t *testing.T) { }) // Test deleting multiple requests. - err = orm.DeleteTransmitRequests(sURL, []*pb.TransmitRequest{ + err = orm.DeleteTransmitRequests(ctx, sURL, []*pb.TransmitRequest{ {Payload: reports[0]}, {Payload: reports[2]}, }) @@ -115,27 +114,27 @@ func TestORM(t *testing.T) { require.NoError(t, err) assert.Equal(t, reports[2], l) - transmissions, err = orm.GetTransmitRequests(sURL, jobID) + transmissions, err = orm.GetTransmitRequests(ctx, sURL, jobID) require.NoError(t, err) require.Empty(t, transmissions) // More inserts. - err = orm.InsertTransmitRequest(sURL, &pb.TransmitRequest{Payload: reports[3]}, jobID, reportContexts[3]) + err = orm.InsertTransmitRequest(ctx, sURL, &pb.TransmitRequest{Payload: reports[3]}, jobID, reportContexts[3]) require.NoError(t, err) - transmissions, err = orm.GetTransmitRequests(sURL, jobID) + transmissions, err = orm.GetTransmitRequests(ctx, sURL, jobID) require.NoError(t, err) require.Equal(t, transmissions, []*Transmission{ {Req: &pb.TransmitRequest{Payload: reports[3]}, ReportCtx: reportContexts[3]}, }) // Duplicate requests are ignored. - err = orm.InsertTransmitRequest(sURL, &pb.TransmitRequest{Payload: reports[3]}, jobID, reportContexts[3]) + err = orm.InsertTransmitRequest(ctx, sURL, &pb.TransmitRequest{Payload: reports[3]}, jobID, reportContexts[3]) require.NoError(t, err) - err = orm.InsertTransmitRequest(sURL, &pb.TransmitRequest{Payload: reports[3]}, jobID, reportContexts[3]) + err = orm.InsertTransmitRequest(ctx, sURL, &pb.TransmitRequest{Payload: reports[3]}, jobID, reportContexts[3]) require.NoError(t, err) - transmissions, err = orm.GetTransmitRequests(sURL, jobID) + transmissions, err = orm.GetTransmitRequests(ctx, sURL, jobID) require.NoError(t, err) require.Equal(t, transmissions, []*Transmission{ {Req: &pb.TransmitRequest{Payload: reports[3]}, ReportCtx: reportContexts[3]}, @@ -146,20 +145,20 @@ func TestORM(t *testing.T) { assert.Equal(t, reports[3], l) // s2 not affected by deletion - transmissions, err = orm.GetTransmitRequests(sURL2, jobID) + transmissions, err = orm.GetTransmitRequests(ctx, sURL2, jobID) require.NoError(t, err) require.Len(t, transmissions, 1) } func TestORM_PruneTransmitRequests(t *testing.T) { + ctx := testutils.Context(t) db := pgtest.NewSqlxDB(t) jobID := rand.Int32() // foreign key constraints disabled so value doesn't matter pgtest.MustExec(t, db, `SET CONSTRAINTS mercury_transmit_requests_job_id_fkey DEFERRED`) pgtest.MustExec(t, db, `SET CONSTRAINTS feed_latest_reports_job_id_fkey DEFERRED`) - lggr := logger.TestLogger(t) - orm := NewORM(db, lggr, pgtest.NewQConfig(true)) + orm := NewORM(db) reports := sampleReports @@ -175,25 +174,25 @@ func TestORM_PruneTransmitRequests(t *testing.T) { } // s1 - err := orm.InsertTransmitRequest(sURL, &pb.TransmitRequest{Payload: reports[0]}, jobID, makeReportContext(1, 1)) + err := orm.InsertTransmitRequest(ctx, sURL, &pb.TransmitRequest{Payload: reports[0]}, jobID, makeReportContext(1, 1)) require.NoError(t, err) - err = orm.InsertTransmitRequest(sURL, &pb.TransmitRequest{Payload: reports[1]}, jobID, makeReportContext(1, 2)) + err = orm.InsertTransmitRequest(ctx, sURL, &pb.TransmitRequest{Payload: reports[1]}, jobID, makeReportContext(1, 2)) require.NoError(t, err) // s2 - should not be touched - err = orm.InsertTransmitRequest(sURL2, &pb.TransmitRequest{Payload: reports[0]}, jobID, makeReportContext(1, 0)) + err = orm.InsertTransmitRequest(ctx, sURL2, &pb.TransmitRequest{Payload: reports[0]}, jobID, makeReportContext(1, 0)) require.NoError(t, err) - err = orm.InsertTransmitRequest(sURL2, &pb.TransmitRequest{Payload: reports[0]}, jobID, makeReportContext(1, 1)) + err = orm.InsertTransmitRequest(ctx, sURL2, &pb.TransmitRequest{Payload: reports[0]}, jobID, makeReportContext(1, 1)) require.NoError(t, err) - err = orm.InsertTransmitRequest(sURL2, &pb.TransmitRequest{Payload: reports[1]}, jobID, makeReportContext(1, 2)) + err = orm.InsertTransmitRequest(ctx, sURL2, &pb.TransmitRequest{Payload: reports[1]}, jobID, makeReportContext(1, 2)) require.NoError(t, err) - err = orm.InsertTransmitRequest(sURL2, &pb.TransmitRequest{Payload: reports[2]}, jobID, makeReportContext(1, 3)) + err = orm.InsertTransmitRequest(ctx, sURL2, &pb.TransmitRequest{Payload: reports[2]}, jobID, makeReportContext(1, 3)) require.NoError(t, err) // Max size greater than number of records, expect no-op - err = orm.PruneTransmitRequests(sURL, jobID, 5) + err = orm.PruneTransmitRequests(ctx, sURL, jobID, 5) require.NoError(t, err) - transmissions, err := orm.GetTransmitRequests(sURL, jobID) + transmissions, err := orm.GetTransmitRequests(ctx, sURL, jobID) require.NoError(t, err) require.Equal(t, transmissions, []*Transmission{ {Req: &pb.TransmitRequest{Payload: reports[1]}, ReportCtx: makeReportContext(1, 2)}, @@ -201,10 +200,10 @@ func TestORM_PruneTransmitRequests(t *testing.T) { }) // Max size equal to number of records, expect no-op - err = orm.PruneTransmitRequests(sURL, jobID, 2) + err = orm.PruneTransmitRequests(ctx, sURL, jobID, 2) require.NoError(t, err) - transmissions, err = orm.GetTransmitRequests(sURL, jobID) + transmissions, err = orm.GetTransmitRequests(ctx, sURL, jobID) require.NoError(t, err) require.Equal(t, transmissions, []*Transmission{ {Req: &pb.TransmitRequest{Payload: reports[1]}, ReportCtx: makeReportContext(1, 2)}, @@ -212,26 +211,26 @@ func TestORM_PruneTransmitRequests(t *testing.T) { }) // Max size is number of records + 1, but jobID differs, expect no-op - err = orm.PruneTransmitRequests(sURL, -1, 2) + err = orm.PruneTransmitRequests(ctx, sURL, -1, 2) require.NoError(t, err) - transmissions, err = orm.GetTransmitRequests(sURL, jobID) + transmissions, err = orm.GetTransmitRequests(ctx, sURL, jobID) require.NoError(t, err) require.Equal(t, []*Transmission{ {Req: &pb.TransmitRequest{Payload: reports[1]}, ReportCtx: makeReportContext(1, 2)}, {Req: &pb.TransmitRequest{Payload: reports[0]}, ReportCtx: makeReportContext(1, 1)}, }, transmissions) - err = orm.InsertTransmitRequest(sURL, &pb.TransmitRequest{Payload: reports[2]}, jobID, makeReportContext(2, 1)) + err = orm.InsertTransmitRequest(ctx, sURL, &pb.TransmitRequest{Payload: reports[2]}, jobID, makeReportContext(2, 1)) require.NoError(t, err) - err = orm.InsertTransmitRequest(sURL, &pb.TransmitRequest{Payload: reports[3]}, jobID, makeReportContext(2, 2)) + err = orm.InsertTransmitRequest(ctx, sURL, &pb.TransmitRequest{Payload: reports[3]}, jobID, makeReportContext(2, 2)) require.NoError(t, err) // Max size is table size - 1, expect the oldest row to be pruned. - err = orm.PruneTransmitRequests(sURL, jobID, 3) + err = orm.PruneTransmitRequests(ctx, sURL, jobID, 3) require.NoError(t, err) - transmissions, err = orm.GetTransmitRequests(sURL, jobID) + transmissions, err = orm.GetTransmitRequests(ctx, sURL, jobID) require.NoError(t, err) require.Equal(t, []*Transmission{ {Req: &pb.TransmitRequest{Payload: reports[3]}, ReportCtx: makeReportContext(2, 2)}, @@ -240,19 +239,19 @@ func TestORM_PruneTransmitRequests(t *testing.T) { }, transmissions) // s2 not touched - transmissions, err = orm.GetTransmitRequests(sURL2, jobID) + transmissions, err = orm.GetTransmitRequests(ctx, sURL2, jobID) require.NoError(t, err) assert.Len(t, transmissions, 3) } func TestORM_InsertTransmitRequest_LatestReport(t *testing.T) { + ctx := testutils.Context(t) db := pgtest.NewSqlxDB(t) jobID := rand.Int32() // foreign key constraints disabled so value doesn't matter pgtest.MustExec(t, db, `SET CONSTRAINTS mercury_transmit_requests_job_id_fkey DEFERRED`) pgtest.MustExec(t, db, `SET CONSTRAINTS feed_latest_reports_job_id_fkey DEFERRED`) - lggr := logger.TestLogger(t) - orm := NewORM(db, lggr, pgtest.NewQConfig(true)) + orm := NewORM(db) feedID := sampleFeedID reports := sampleReports @@ -268,13 +267,13 @@ func TestORM_InsertTransmitRequest_LatestReport(t *testing.T) { } } - err := orm.InsertTransmitRequest(sURL, &pb.TransmitRequest{Payload: reports[0]}, jobID, makeReportContext( + err := orm.InsertTransmitRequest(ctx, sURL, &pb.TransmitRequest{Payload: reports[0]}, jobID, makeReportContext( 0, 0, )) require.NoError(t, err) // this should be ignored, because report context is the same - err = orm.InsertTransmitRequest(sURL2, &pb.TransmitRequest{Payload: reports[1]}, jobID, makeReportContext( + err = orm.InsertTransmitRequest(ctx, sURL2, &pb.TransmitRequest{Payload: reports[1]}, jobID, makeReportContext( 0, 0, )) require.NoError(t, err) @@ -284,7 +283,7 @@ func TestORM_InsertTransmitRequest_LatestReport(t *testing.T) { assert.Equal(t, reports[0], l) t.Run("replaces if epoch and round are larger", func(t *testing.T) { - err = orm.InsertTransmitRequest("foo", &pb.TransmitRequest{Payload: reports[1]}, jobID, makeReportContext(1, 1)) + err = orm.InsertTransmitRequest(ctx, "foo", &pb.TransmitRequest{Payload: reports[1]}, jobID, makeReportContext(1, 1)) require.NoError(t, err) l, err = orm.LatestReport(testutils.Context(t), feedID) @@ -292,7 +291,7 @@ func TestORM_InsertTransmitRequest_LatestReport(t *testing.T) { assert.Equal(t, reports[1], l) }) t.Run("replaces if epoch is the same but round is greater", func(t *testing.T) { - err = orm.InsertTransmitRequest(sURL, &pb.TransmitRequest{Payload: reports[2]}, jobID, makeReportContext(1, 2)) + err = orm.InsertTransmitRequest(ctx, sURL, &pb.TransmitRequest{Payload: reports[2]}, jobID, makeReportContext(1, 2)) require.NoError(t, err) l, err = orm.LatestReport(testutils.Context(t), feedID) @@ -300,7 +299,7 @@ func TestORM_InsertTransmitRequest_LatestReport(t *testing.T) { assert.Equal(t, reports[2], l) }) t.Run("replaces if epoch is larger but round is smaller", func(t *testing.T) { - err = orm.InsertTransmitRequest("bar", &pb.TransmitRequest{Payload: reports[3]}, jobID, makeReportContext(2, 1)) + err = orm.InsertTransmitRequest(ctx, "bar", &pb.TransmitRequest{Payload: reports[3]}, jobID, makeReportContext(2, 1)) require.NoError(t, err) l, err = orm.LatestReport(testutils.Context(t), feedID) @@ -308,7 +307,7 @@ func TestORM_InsertTransmitRequest_LatestReport(t *testing.T) { assert.Equal(t, reports[3], l) }) t.Run("does not overwrite if epoch/round is the same", func(t *testing.T) { - err = orm.InsertTransmitRequest(sURL, &pb.TransmitRequest{Payload: reports[0]}, jobID, makeReportContext(2, 1)) + err = orm.InsertTransmitRequest(ctx, sURL, &pb.TransmitRequest{Payload: reports[0]}, jobID, makeReportContext(2, 1)) require.NoError(t, err) l, err = orm.LatestReport(testutils.Context(t), feedID) diff --git a/core/services/relay/evm/mercury/persistence_manager.go b/core/services/relay/evm/mercury/persistence_manager.go index dc805c12e7b..d49d0d4ed01 100644 --- a/core/services/relay/evm/mercury/persistence_manager.go +++ b/core/services/relay/evm/mercury/persistence_manager.go @@ -8,8 +8,8 @@ import ( ocrtypes "github.com/smartcontractkit/libocr/offchainreporting2plus/types" "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm/mercury/wsrpc/pb" "github.com/smartcontractkit/chainlink/v2/core/utils" ) @@ -69,11 +69,11 @@ func (pm *PersistenceManager) Close() error { } func (pm *PersistenceManager) Insert(ctx context.Context, req *pb.TransmitRequest, reportCtx ocrtypes.ReportContext) error { - return pm.orm.InsertTransmitRequest(pm.serverURL, req, pm.jobID, reportCtx, pg.WithParentCtx(ctx)) + return pm.orm.InsertTransmitRequest(ctx, pm.serverURL, req, pm.jobID, reportCtx) } func (pm *PersistenceManager) Delete(ctx context.Context, req *pb.TransmitRequest) error { - return pm.orm.DeleteTransmitRequests(pm.serverURL, []*pb.TransmitRequest{req}, pg.WithParentCtx(ctx)) + return pm.orm.DeleteTransmitRequests(ctx, pm.serverURL, []*pb.TransmitRequest{req}) } func (pm *PersistenceManager) AsyncDelete(req *pb.TransmitRequest) { @@ -81,7 +81,7 @@ func (pm *PersistenceManager) AsyncDelete(req *pb.TransmitRequest) { } func (pm *PersistenceManager) Load(ctx context.Context) ([]*Transmission, error) { - return pm.orm.GetTransmitRequests(pm.serverURL, pm.jobID, pg.WithParentCtx(ctx)) + return pm.orm.GetTransmitRequests(ctx, pm.serverURL, pm.jobID) } func (pm *PersistenceManager) runFlushDeletesLoop() { @@ -98,7 +98,7 @@ func (pm *PersistenceManager) runFlushDeletesLoop() { return case <-ticker.C: queuedReqs := pm.resetDeleteQueue() - if err := pm.orm.DeleteTransmitRequests(pm.serverURL, queuedReqs, pg.WithParentCtx(ctx)); err != nil { + if err := pm.orm.DeleteTransmitRequests(ctx, pm.serverURL, queuedReqs); err != nil { pm.lggr.Errorw("Failed to delete queued transmit requests", "err", err) pm.addToDeleteQueue(queuedReqs...) } else { @@ -111,7 +111,7 @@ func (pm *PersistenceManager) runFlushDeletesLoop() { func (pm *PersistenceManager) runPruneLoop() { defer pm.wg.Done() - ctx, cancel := pm.stopCh.Ctx(context.Background()) + ctx, cancel := pm.stopCh.NewCtx() defer cancel() ticker := time.NewTicker(utils.WithJitter(pm.pruneFrequency)) @@ -121,11 +121,15 @@ func (pm *PersistenceManager) runPruneLoop() { ticker.Stop() return case <-ticker.C: - if err := pm.orm.PruneTransmitRequests(pm.serverURL, pm.jobID, pm.maxTransmitQueueSize, pg.WithParentCtx(ctx), pg.WithLongQueryTimeout()); err != nil { - pm.lggr.Errorw("Failed to prune transmit requests table", "err", err) - } else { - pm.lggr.Debugw("Pruned transmit requests table") - } + func(ctx context.Context) { + ctx, cancelPrune := context.WithTimeout(sqlutil.WithoutDefaultTimeout(ctx), time.Minute) + defer cancelPrune() + if err := pm.orm.PruneTransmitRequests(ctx, pm.serverURL, pm.jobID, pm.maxTransmitQueueSize); err != nil { + pm.lggr.Errorw("Failed to prune transmit requests table", "err", err) + } else { + pm.lggr.Debugw("Pruned transmit requests table") + } + }(ctx) } } } diff --git a/core/services/relay/evm/mercury/persistence_manager_test.go b/core/services/relay/evm/mercury/persistence_manager_test.go index 15b1424f1a4..1ba999614a6 100644 --- a/core/services/relay/evm/mercury/persistence_manager_test.go +++ b/core/services/relay/evm/mercury/persistence_manager_test.go @@ -22,7 +22,7 @@ import ( func bootstrapPersistenceManager(t *testing.T, jobID int32, db *sqlx.DB) (*PersistenceManager, *observer.ObservedLogs) { t.Helper() lggr, observedLogs := logger.TestLoggerObserved(t, zapcore.DebugLevel) - orm := NewORM(db, lggr, pgtest.NewQConfig(true)) + orm := NewORM(db) return NewPersistenceManager(lggr, "mercuryserver.example", orm, jobID, 2, 5*time.Millisecond, 5*time.Millisecond), observedLogs } diff --git a/core/services/relay/evm/mercury/transmitter_test.go b/core/services/relay/evm/mercury/transmitter_test.go index d7d62a9f422..46bf116ed3a 100644 --- a/core/services/relay/evm/mercury/transmitter_test.go +++ b/core/services/relay/evm/mercury/transmitter_test.go @@ -17,7 +17,7 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/logger" mercurytypes "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm/mercury/types" "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm/mercury/wsrpc" - mocks "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm/mercury/wsrpc/mocks" + "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm/mercury/wsrpc/mocks" "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm/mercury/wsrpc/pb" ) @@ -28,7 +28,7 @@ func Test_MercuryTransmitter_Transmit(t *testing.T) { pgtest.MustExec(t, db, `SET CONSTRAINTS mercury_transmit_requests_job_id_fkey DEFERRED`) pgtest.MustExec(t, db, `SET CONSTRAINTS feed_latest_reports_job_id_fkey DEFERRED`) codec := new(mockCodec) - orm := NewORM(db, lggr, pgtest.NewQConfig(true)) + orm := NewORM(db) clients := map[string]wsrpc.Client{} t.Run("with one mercury server", func(t *testing.T) { @@ -109,7 +109,7 @@ func Test_MercuryTransmitter_LatestTimestamp(t *testing.T) { var jobID int32 codec := new(mockCodec) - orm := NewORM(db, lggr, pgtest.NewQConfig(true)) + orm := NewORM(db) clients := map[string]wsrpc.Client{} t.Run("successful query", func(t *testing.T) { @@ -211,7 +211,7 @@ func Test_MercuryTransmitter_LatestPrice(t *testing.T) { var jobID int32 codec := new(mockCodec) - orm := NewORM(db, lggr, pgtest.NewQConfig(true)) + orm := NewORM(db) clients := map[string]wsrpc.Client{} t.Run("successful query", func(t *testing.T) { @@ -287,7 +287,7 @@ func Test_MercuryTransmitter_FetchInitialMaxFinalizedBlockNumber(t *testing.T) { db := pgtest.NewSqlxDB(t) var jobID int32 codec := new(mockCodec) - orm := NewORM(db, lggr, pgtest.NewQConfig(true)) + orm := NewORM(db) clients := map[string]wsrpc.Client{} t.Run("successful query", func(t *testing.T) { diff --git a/core/services/relay/evm/mercury/types/types.go b/core/services/relay/evm/mercury/types/types.go index 49bffb6c290..972367940b5 100644 --- a/core/services/relay/evm/mercury/types/types.go +++ b/core/services/relay/evm/mercury/types/types.go @@ -7,12 +7,10 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" ocrtypes "github.com/smartcontractkit/libocr/offchainreporting2plus/types" - - "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) type DataSourceORM interface { - LatestReport(ctx context.Context, feedID [32]byte, qopts ...pg.QOpt) (report []byte, err error) + LatestReport(ctx context.Context, feedID [32]byte) (report []byte, err error) } type ReportCodec interface { diff --git a/core/services/relay/evm/mercury/v1/data_source_test.go b/core/services/relay/evm/mercury/v1/data_source_test.go index bd0f803cada..197d802a3b3 100644 --- a/core/services/relay/evm/mercury/v1/data_source_test.go +++ b/core/services/relay/evm/mercury/v1/data_source_test.go @@ -25,7 +25,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/job" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm" mercurymocks "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm/mercury/mocks" @@ -65,7 +64,7 @@ type mockORM struct { err error } -func (m *mockORM) LatestReport(ctx context.Context, feedID [32]byte, qopts ...pg.QOpt) (report []byte, err error) { +func (m *mockORM) LatestReport(ctx context.Context, feedID [32]byte) (report []byte, err error) { return m.report, m.err } diff --git a/core/services/relay/evm/mercury/v2/data_source_test.go b/core/services/relay/evm/mercury/v2/data_source_test.go index c9ae37ae018..19af909c8e9 100644 --- a/core/services/relay/evm/mercury/v2/data_source_test.go +++ b/core/services/relay/evm/mercury/v2/data_source_test.go @@ -15,7 +15,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" mercurymocks "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm/mercury/mocks" "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm/mercury/utils" @@ -59,7 +58,7 @@ type mockORM struct { err error } -func (m *mockORM) LatestReport(ctx context.Context, feedID [32]byte, qopts ...pg.QOpt) (report []byte, err error) { +func (m *mockORM) LatestReport(ctx context.Context, feedID [32]byte) (report []byte, err error) { return m.report, m.err } diff --git a/core/services/relay/evm/mercury/v3/data_source_test.go b/core/services/relay/evm/mercury/v3/data_source_test.go index 4ff713abb21..ffcdc28f81c 100644 --- a/core/services/relay/evm/mercury/v3/data_source_test.go +++ b/core/services/relay/evm/mercury/v3/data_source_test.go @@ -15,7 +15,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" mercurymocks "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm/mercury/mocks" "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm/mercury/utils" @@ -59,7 +58,7 @@ type mockORM struct { err error } -func (m *mockORM) LatestReport(ctx context.Context, feedID [32]byte, qopts ...pg.QOpt) (report []byte, err error) { +func (m *mockORM) LatestReport(ctx context.Context, feedID [32]byte) (report []byte, err error) { return m.report, m.err } From f50f2dd2c03cb0fdc95dca76fccc53db07fdf23d Mon Sep 17 00:00:00 2001 From: Bartek Tofel Date: Wed, 17 Apr 2024 14:22:28 +0200 Subject: [PATCH 06/19] use latest Seth (#12784) * use latest Seth * use latest Seth * update field names * update field names * adjust gitignore * gomodtidy --- .gitignore | 2 ++ integration-tests/actions/seth/actions.go | 4 +-- integration-tests/actions/seth/refund.go | 4 +-- integration-tests/experiments/gas_test.go | 9 ++++++ integration-tests/go.mod | 2 +- integration-tests/go.sum | 4 +-- integration-tests/load/go.mod | 2 +- integration-tests/load/go.sum | 4 +-- integration-tests/testconfig/default.toml | 34 +++++++++++------------ 9 files changed, 37 insertions(+), 28 deletions(-) diff --git a/.gitignore b/.gitignore index ccf8a006e7b..7d07300311f 100644 --- a/.gitignore +++ b/.gitignore @@ -97,3 +97,5 @@ override*.toml # Pythin venv .venv/ + +ocr_soak_report.csv \ No newline at end of file diff --git a/integration-tests/actions/seth/actions.go b/integration-tests/actions/seth/actions.go index d8daba3dbc3..7b128620158 100644 --- a/integration-tests/actions/seth/actions.go +++ b/integration-tests/actions/seth/actions.go @@ -147,9 +147,7 @@ func SendFunds(logger zerolog.Logger, client *seth.Client, payload FundsToSendPa if payload.GasTipCap != nil { gasTipCap = payload.GasTipCap } - } - - if !client.Cfg.Network.EIP1559DynamicFees { + } else { if payload.GasPrice == nil { txOptions := client.NewTXOpts((seth.WithGasLimit(gasLimit))) gasPrice = txOptions.GasPrice diff --git a/integration-tests/actions/seth/refund.go b/integration-tests/actions/seth/refund.go index 9a32d22de5d..cca4659cb6d 100644 --- a/integration-tests/actions/seth/refund.go +++ b/integration-tests/actions/seth/refund.go @@ -281,7 +281,7 @@ func ReturnFunds(log zerolog.Logger, sethClient *seth.Client, chainlinkNodes []c } // if not set, it will be just set to empty string, which is okay as long as gas estimation is disabled - txPriority := sethClient.Cfg.Network.GasEstimationTxPriority + txPriority := sethClient.Cfg.Network.GasPriceEstimationTxPriority txTimeout := sethClient.Cfg.Network.TxnTimeout.Duration() if sethClient.Cfg.IsExperimentEnabled(seth.Experiment_SlowFundsReturn) { @@ -291,7 +291,7 @@ func ReturnFunds(log zerolog.Logger, sethClient *seth.Client, chainlinkNodes []c } estimations := sethClient.CalculateGasEstimations(seth.GasEstimationRequest{ - GasEstimationEnabled: sethClient.Cfg.Network.GasEstimationEnabled, + GasEstimationEnabled: sethClient.Cfg.Network.GasPriceEstimationEnabled, FallbackGasPrice: sethClient.Cfg.Network.GasPrice, FallbackGasFeeCap: sethClient.Cfg.Network.GasFeeCap, FallbackGasTipCap: sethClient.Cfg.Network.GasTipCap, diff --git a/integration-tests/experiments/gas_test.go b/integration-tests/experiments/gas_test.go index b3ca8e53a25..ba096b69dbc 100644 --- a/integration-tests/experiments/gas_test.go +++ b/integration-tests/experiments/gas_test.go @@ -1,6 +1,7 @@ package experiments import ( + "math/big" "testing" "time" @@ -9,6 +10,7 @@ import ( "github.com/smartcontractkit/chainlink-testing-framework/logging" "github.com/smartcontractkit/chainlink-testing-framework/networks" + actions_seth "github.com/smartcontractkit/chainlink/integration-tests/actions/seth" "github.com/smartcontractkit/chainlink/integration-tests/contracts" tc "github.com/smartcontractkit/chainlink/integration-tests/testconfig" "github.com/smartcontractkit/chainlink/integration-tests/utils" @@ -31,6 +33,13 @@ func TestGasExperiment(t *testing.T) { seth, err := seth.NewClientWithConfig(&sethCfg) require.NoError(t, err, "Error creating seth client") + _, err = actions_seth.SendFunds(l, seth, actions_seth.FundsToSendPayload{ + ToAddress: seth.Addresses[0], + Amount: big.NewInt(10_000_000), + PrivateKey: seth.PrivateKeys[0], + }) + require.NoError(t, err, "Error sending funds") + for i := 0; i < 1; i++ { _, err = contracts.DeployLinkTokenContract(l, seth) require.NoError(t, err, "Error deploying LINK contract") diff --git a/integration-tests/go.mod b/integration-tests/go.mod index d1584fc1b7a..b3f0f7a90c6 100644 --- a/integration-tests/go.mod +++ b/integration-tests/go.mod @@ -29,7 +29,7 @@ require ( github.com/smartcontractkit/chainlink-vrf v0.0.0-20240222010609-cd67d123c772 github.com/smartcontractkit/chainlink/v2 v2.0.0-00010101000000-000000000000 github.com/smartcontractkit/libocr v0.0.0-20240326191951-2bbe9382d052 - github.com/smartcontractkit/seth v0.1.3 + github.com/smartcontractkit/seth v0.1.5 github.com/smartcontractkit/wasp v0.4.5 github.com/spf13/cobra v1.8.0 github.com/stretchr/testify v1.9.0 diff --git a/integration-tests/go.sum b/integration-tests/go.sum index 6aee95d7784..957e3988143 100644 --- a/integration-tests/go.sum +++ b/integration-tests/go.sum @@ -1543,8 +1543,8 @@ github.com/smartcontractkit/grpc-proxy v0.0.0-20230731113816-f1be6620749f h1:hgJ github.com/smartcontractkit/grpc-proxy v0.0.0-20230731113816-f1be6620749f/go.mod h1:MvMXoufZAtqExNexqi4cjrNYE9MefKddKylxjS+//n0= github.com/smartcontractkit/libocr v0.0.0-20240326191951-2bbe9382d052 h1:1WFjrrVrWoQ9UpVMh7Mx4jDpzhmo1h8hFUKd9awIhIU= github.com/smartcontractkit/libocr v0.0.0-20240326191951-2bbe9382d052/go.mod h1:SJEZCHgMCAzzBvo9vMV2DQ9onfEcIJCYSViyP4JI6c4= -github.com/smartcontractkit/seth v0.1.3 h1:pQc+SJeONWg73lQOiY5ZmBbvvVqEVBmTM9PiJOr+n4s= -github.com/smartcontractkit/seth v0.1.3/go.mod h1:2TMOZQ8WTAw7rR1YBbXpnad6VmT/+xDd/nXLmB7Eero= +github.com/smartcontractkit/seth v0.1.5 h1:tobdA3uzRHubN/ytE6bSq1dtvmaXjKdaHEGW5Re9I1U= +github.com/smartcontractkit/seth v0.1.5/go.mod h1:2TMOZQ8WTAw7rR1YBbXpnad6VmT/+xDd/nXLmB7Eero= github.com/smartcontractkit/tdh2/go/ocr2/decryptionplugin v0.0.0-20230906073235-9e478e5e19f1 h1:yiKnypAqP8l0OX0P3klzZ7SCcBUxy5KqTAKZmQOvSQE= github.com/smartcontractkit/tdh2/go/ocr2/decryptionplugin v0.0.0-20230906073235-9e478e5e19f1/go.mod h1:q6f4fe39oZPdsh1i57WznEZgxd8siidMaSFq3wdPmVg= github.com/smartcontractkit/tdh2/go/tdh2 v0.0.0-20230906073235-9e478e5e19f1 h1:Dai1bn+Q5cpeGMQwRdjOdVjG8mmFFROVkSKuUgBErRQ= diff --git a/integration-tests/load/go.mod b/integration-tests/load/go.mod index 85f673986f1..cbde975afcc 100644 --- a/integration-tests/load/go.mod +++ b/integration-tests/load/go.mod @@ -21,7 +21,7 @@ require ( github.com/smartcontractkit/chainlink/integration-tests v0.0.0-20240214231432-4ad5eb95178c github.com/smartcontractkit/chainlink/v2 v2.9.0-beta0.0.20240216210048-da02459ddad8 github.com/smartcontractkit/libocr v0.0.0-20240326191951-2bbe9382d052 - github.com/smartcontractkit/seth v0.1.3 + github.com/smartcontractkit/seth v0.1.5 github.com/smartcontractkit/tdh2/go/tdh2 v0.0.0-20230906073235-9e478e5e19f1 github.com/smartcontractkit/wasp v0.4.6 github.com/stretchr/testify v1.9.0 diff --git a/integration-tests/load/go.sum b/integration-tests/load/go.sum index cc8b1238a5a..2ae24681388 100644 --- a/integration-tests/load/go.sum +++ b/integration-tests/load/go.sum @@ -1528,8 +1528,8 @@ github.com/smartcontractkit/grpc-proxy v0.0.0-20230731113816-f1be6620749f h1:hgJ github.com/smartcontractkit/grpc-proxy v0.0.0-20230731113816-f1be6620749f/go.mod h1:MvMXoufZAtqExNexqi4cjrNYE9MefKddKylxjS+//n0= github.com/smartcontractkit/libocr v0.0.0-20240326191951-2bbe9382d052 h1:1WFjrrVrWoQ9UpVMh7Mx4jDpzhmo1h8hFUKd9awIhIU= github.com/smartcontractkit/libocr v0.0.0-20240326191951-2bbe9382d052/go.mod h1:SJEZCHgMCAzzBvo9vMV2DQ9onfEcIJCYSViyP4JI6c4= -github.com/smartcontractkit/seth v0.1.3 h1:pQc+SJeONWg73lQOiY5ZmBbvvVqEVBmTM9PiJOr+n4s= -github.com/smartcontractkit/seth v0.1.3/go.mod h1:2TMOZQ8WTAw7rR1YBbXpnad6VmT/+xDd/nXLmB7Eero= +github.com/smartcontractkit/seth v0.1.5 h1:tobdA3uzRHubN/ytE6bSq1dtvmaXjKdaHEGW5Re9I1U= +github.com/smartcontractkit/seth v0.1.5/go.mod h1:2TMOZQ8WTAw7rR1YBbXpnad6VmT/+xDd/nXLmB7Eero= github.com/smartcontractkit/tdh2/go/ocr2/decryptionplugin v0.0.0-20230906073235-9e478e5e19f1 h1:yiKnypAqP8l0OX0P3klzZ7SCcBUxy5KqTAKZmQOvSQE= github.com/smartcontractkit/tdh2/go/ocr2/decryptionplugin v0.0.0-20230906073235-9e478e5e19f1/go.mod h1:q6f4fe39oZPdsh1i57WznEZgxd8siidMaSFq3wdPmVg= github.com/smartcontractkit/tdh2/go/tdh2 v0.0.0-20230906073235-9e478e5e19f1 h1:Dai1bn+Q5cpeGMQwRdjOdVjG8mmFFROVkSKuUgBErRQ= diff --git a/integration-tests/testconfig/default.toml b/integration-tests/testconfig/default.toml index 92f8bcd7f80..a277c22b4c2 100644 --- a/integration-tests/testconfig/default.toml +++ b/integration-tests/testconfig/default.toml @@ -60,7 +60,7 @@ transfer_gas_fee = 21_000 # gas limit should be explicitly set only if you are connecting to a node that's incapable of estimating gas limit itself (should only happen for very old versions) # gas_limit = 8_000_000 -# manual settings, used when gas_estimation_enabled is false or when it fails +# manual settings, used when gas_price_estimation_enabled is false or when it fails # legacy transactions gas_price = 1_000_000_000 @@ -77,11 +77,11 @@ eip_1559_dynamic_fees = true # automated gas estimation for live networks # if set to true we will dynamically estimate gas for every transaction (based on suggested values, priority and congestion rate for last X blocks) -# gas_estimation_enabled = true +# gas_price_estimation_enabled = true # number of blocks to use for congestion rate estimation (it will determine buffer added on top of suggested values) -# gas_estimation_blocks = 100 +# gas_price_estimation_blocks = 100 # transaction priority, which determines adjustment factor multiplier applied to suggested values (fast - 1.2x, standard - 1x, slow - 0.8x) -# gas_estimation_tx_priority = "standard" +# gas_price_estimation_tx_priority = "standard" # URLs # if set they will overwrite URLs from EVMNetwork that Seth uses, can be either WS(S) or HTTP(S) @@ -94,7 +94,7 @@ eip_1559_dynamic_fees = true # we use hardcoded value in order to be estimate how much funds are available for sending or returning after tx costs have been paid transfer_gas_fee = 21_000 -# manual settings, used when gas_estimation_enabled is false or when it fails +# manual settings, used when gas_price_estimation_enabled is false or when it fails # legacy transactions gas_price = 30_000_000_000 @@ -110,11 +110,11 @@ eip_1559_dynamic_fees = false # automated gas estimation for live networks # if set to true we will dynamically estimate gas for every transaction (based on suggested values, priority and congestion rate for last X blocks) -# gas_estimation_enabled = true +# gas_price_estimation_enabled = true # number of blocks to use for congestion rate estimation (it will determine buffer added on top of suggested values) -# gas_estimation_blocks = 100 +# gas_price_estimation_blocks = 100 # transaction priority, which determines adjustment factor multiplier applied to suggested values (fast - 1.2x, standard - 1x, slow - 0.8x) -# gas_estimation_tx_priority = "standard" +# gas_price_estimation_tx_priority = "standard" # URLs # if set they will overwrite URLs from EVMNetwork that Seth uses, can be either WS(S) or HTTP(S) @@ -127,7 +127,7 @@ eip_1559_dynamic_fees = false # we use hardcoded value in order to be estimate how much funds are available for sending or returning after tx costs have been paid transfer_gas_fee = 21_000 -# manual settings, used when gas_estimation_enabled is false or when it fails +# manual settings, used when gas_price_estimation_enabled is false or when it fails # legacy transactions gas_price = 50_000_000_000 @@ -143,11 +143,11 @@ eip_1559_dynamic_fees = true # automated gas estimation for live networks # if set to true we will dynamically estimate gas for every transaction (based on suggested values, priority and congestion rate for last X blocks) -# gas_estimation_enabled = true +# gas_price_estimation_enabled = true # number of blocks to use for congestion rate estimation (it will determine buffer added on top of suggested values) -# gas_estimation_blocks = 100 +# gas_price_estimation_blocks = 100 # transaction priority, which determines adjustment factor multiplier applied to suggested values (fast - 1.2x, standard - 1x, slow - 0.8x) -# gas_estimation_tx_priority = "standard" +# gas_price_estimation_tx_priority = "standard" # URLs # if set they will overwrite URLs from EVMNetwork that Seth uses, can be either WS(S) or HTTP(S) @@ -160,7 +160,7 @@ eip_1559_dynamic_fees = true # we use hardcoded value in order to be estimate how much funds are available for sending or returning after tx costs have been paid transfer_gas_fee = 21_000 -# manual settings, used when gas_estimation_enabled is false or when it fails +# manual settings, used when gas_price_estimation_enabled is false or when it fails # legacy transactions gas_price = 1_800_000_000 @@ -176,11 +176,11 @@ eip_1559_dynamic_fees = false # automated gas estimation for live networks # if set to true we will dynamically estimate gas for every transaction (based on suggested values, priority and congestion rate for last X blocks) -# gas_estimation_enabled = true +# gas_price_estimation_enabled = true # number of blocks to use for congestion rate estimation (it will determine buffer added on top of suggested values) -# gas_estimation_blocks = 100 +# gas_price_estimation_blocks = 100 # transaction priority, which determines adjustment factor multiplier applied to suggested values (fast - 1.2x, standard - 1x, slow - 0.8x) -# gas_estimation_tx_priority = "standard" +# gas_price_estimation_tx_priority = "standard" # URLs # if set they will overwrite URLs from EVMNetwork that Seth uses, can be either WS(S) or HTTP(S) @@ -193,7 +193,7 @@ eip_1559_dynamic_fees = false # we use hardcoded value in order to be estimate how much funds are available for sending or returning after tx costs have been paid transfer_gas_fee = 21_000 -# manual settings, used when gas_estimation_enabled is false or when it fails +# manual settings, used when gas_price_estimation_enabled is false or when it fails # legacy transactions gas_price = 50_000_000 From 848c5334419026300db28c9c654e8d0df7e7254a Mon Sep 17 00:00:00 2001 From: Tate Date: Wed, 17 Apr 2024 09:18:22 -0600 Subject: [PATCH 07/19] Remove defaults in json files since the defaults are supposed to come from the script (#12804) --- .../smoke/automation_test.go_test_list.json | 13 --- .../evm_node_compatibility_test_list.json | 90 +++++++------------ .../smoke/keeper_test.go_test_list.json | 9 -- .../smoke/log_poller_test.go_test_list.json | 30 +++---- 4 files changed, 40 insertions(+), 102 deletions(-) diff --git a/integration-tests/smoke/automation_test.go_test_list.json b/integration-tests/smoke/automation_test.go_test_list.json index 3e7a82effd3..03029c9018b 100644 --- a/integration-tests/smoke/automation_test.go_test_list.json +++ b/integration-tests/smoke/automation_test.go_test_list.json @@ -2,7 +2,6 @@ "tests": [ { "name": "TestAutomationBasic", - "label": "ubuntu-latest", "nodes": 3, "run":[ {"name":"registry_2_0"}, @@ -12,7 +11,6 @@ }, { "name": "TestAutomationBasic", - "label": "ubuntu-latest", "nodes": 3, "run":[ {"name":"registry_2_1_with_mercury_v02"}, @@ -22,7 +20,6 @@ }, { "name": "TestAutomationBasic", - "label": "ubuntu-latest", "nodes": 3, "run":[ {"name":"registry_2_2_conditional"}, @@ -32,7 +29,6 @@ }, { "name": "TestAutomationBasic", - "label": "ubuntu-latest", "nodes": 2, "run":[ {"name":"registry_2_2_with_mercury_v03"}, @@ -41,47 +37,38 @@ }, { "name": "TestSetUpkeepTriggerConfig", - "label": "ubuntu-latest", "nodes": 2 }, { "name": "TestAutomationAddFunds", - "label": "ubuntu-latest", "nodes": 3 }, { "name": "TestAutomationPauseUnPause", - "label": "ubuntu-latest", "nodes": 3 }, { "name": "TestAutomationRegisterUpkeep", - "label": "ubuntu-latest", "nodes": 3 }, { "name": "TestAutomationPauseRegistry", - "label": "ubuntu-latest", "nodes": 3 }, { "name": "TestAutomationKeeperNodesDown", - "label": "ubuntu-latest", "nodes": 3 }, { "name": "TestAutomationPerformSimulation", - "label": "ubuntu-latest", "nodes": 3 }, { "name": "TestAutomationCheckPerformGasLimit", - "label": "ubuntu-latest", "nodes": 3 }, { "name": "TestUpdateCheckData", - "label": "ubuntu-latest", "nodes": 3 } ] diff --git a/integration-tests/smoke/evm_node_compatibility_test_list.json b/integration-tests/smoke/evm_node_compatibility_test_list.json index c14a2b54a3e..45b303a0a27 100644 --- a/integration-tests/smoke/evm_node_compatibility_test_list.json +++ b/integration-tests/smoke/evm_node_compatibility_test_list.json @@ -4,211 +4,181 @@ "product": "ocr", "name": "TestOCRBasic", "eth_client": "geth", - "docker_image": "ethereum/client-go:latest_stable", - "label": "ubuntu-latest" + "docker_image": "ethereum/client-go:latest_stable" }, { "product": "ocr", "name": "TestOCRBasic", "eth_client": "geth", - "docker_image": "ethereum/client-go:v1.13.0", - "label": "ubuntu-latest" + "docker_image": "ethereum/client-go:v1.13.0" }, { "product": "ocr", "name": "TestOCRBasic", "eth_client": "geth", - "docker_image": "ethereum/client-go:v1.12.0", - "label": "ubuntu-latest" + "docker_image": "ethereum/client-go:v1.12.0" }, { "product": "ocr", "name": "TestOCRBasic", "eth_client": "geth", - "docker_image": "ethereum/client-go:v1.11.0", - "label": "ubuntu-latest" + "docker_image": "ethereum/client-go:v1.11.0" }, { "product": "ocr", "name": "TestOCRBasic", "eth_client": "geth", - "docker_image": "ethereum/client-go:v1.10.0", - "label": "ubuntu-latest" + "docker_image": "ethereum/client-go:v1.10.0" }, { "product": "ocr2", "name": "TestOCRv2Request", "eth_client": "geth", - "docker_image": "ethereum/client-go:latest_stable", - "label": "ubuntu-latest" + "docker_image": "ethereum/client-go:latest_stable" }, { "product": "ocr2", "name": "TestOCRv2Request", "eth_client": "geth", - "docker_image": "ethereum/client-go:v1.13.0", - "label": "ubuntu-latest" + "docker_image": "ethereum/client-go:v1.13.0" }, { "product": "ocr2", "name": "TestOCRv2Request", "eth_client": "geth", - "docker_image": "ethereum/client-go:v1.12.0", - "label": "ubuntu-latest" + "docker_image": "ethereum/client-go:v1.12.0" }, { "product": "ocr2", "name": "TestOCRv2Request", "eth_client": "geth", - "docker_image": "ethereum/client-go:v1.11.0", - "label": "ubuntu-latest" + "docker_image": "ethereum/client-go:v1.11.0" }, { "product": "ocr2", "name": "TestOCRv2Request", "eth_client": "geth", - "docker_image": "ethereum/client-go:v1.10.0", - "label": "ubuntu-latest" + "docker_image": "ethereum/client-go:v1.10.0" }, { "product": "vrf", "name": "TestVRFBasic", "eth_client": "geth", - "docker_image": "ethereum/client-go:latest_stable", - "label": "ubuntu-latest" + "docker_image": "ethereum/client-go:latest_stable" }, { "product": "vrf", "name": "TestVRFBasic", "eth_client": "geth", - "docker_image": "ethereum/client-go:v1.13.0", - "label": "ubuntu-latest" + "docker_image": "ethereum/client-go:v1.13.0" }, { "product": "vrf", "name": "TestVRFBasic", "eth_client": "geth", - "docker_image": "ethereum/client-go:v1.12.0", - "label": "ubuntu-latest" + "docker_image": "ethereum/client-go:v1.12.0" }, { "product": "vrf", "name": "TestVRFBasic", "eth_client": "geth", - "docker_image": "ethereum/client-go:v1.11.0", - "label": "ubuntu-latest" + "docker_image": "ethereum/client-go:v1.11.0" }, { "product": "vrf", "name": "TestVRFBasic", "eth_client": "geth", - "docker_image": "ethereum/client-go:v1.10.0", - "label": "ubuntu-latest" + "docker_image": "ethereum/client-go:v1.10.0" }, { "product": "vrfv2", "name": "TestVRFv2Basic/Request Randomness", "eth_client": "geth", - "docker_image": "ethereum/client-go:latest_stable", - "label": "ubuntu-latest" + "docker_image": "ethereum/client-go:latest_stable" }, { "product": "vrfv2", "name": "TestVRFv2Basic/Request Randomness", "eth_client": "geth", - "docker_image": "ethereum/client-go:v1.13.0", - "label": "ubuntu-latest" + "docker_image": "ethereum/client-go:v1.13.0" }, { "product": "vrfv2", "name": "TestVRFv2Basic/Request Randomness", "eth_client": "geth", - "docker_image": "ethereum/client-go:v1.12.0", - "label": "ubuntu-latest" + "docker_image": "ethereum/client-go:v1.12.0" }, { "product": "vrfv2", "name": "TestVRFv2Basic/Request Randomness", "eth_client": "geth", - "docker_image": "ethereum/client-go:v1.11.0", - "label": "ubuntu-latest" + "docker_image": "ethereum/client-go:v1.11.0" }, { "product": "vrfv2", "name": "TestVRFv2Basic/Request Randomness", "eth_client": "geth", - "docker_image": "ethereum/client-go:v1.10.0", - "label": "ubuntu-latest" + "docker_image": "ethereum/client-go:v1.10.0" }, { "product": "vrfv2plus", "name": "TestVRFv2Plus/Link Billing", "eth_client": "geth", - "docker_image": "ethereum/client-go:latest_stable", - "label": "ubuntu-latest" + "docker_image": "ethereum/client-go:latest_stable" }, { "product": "vrfv2plus", "name": "TestVRFv2Plus/Link Billing", "eth_client": "geth", - "docker_image": "ethereum/client-go:v1.13.0", - "label": "ubuntu-latest" + "docker_image": "ethereum/client-go:v1.13.0" }, { "product": "vrfv2plus", "name": "TestVRFv2Plus/Link Billing", "eth_client": "geth", - "docker_image": "ethereum/client-go:v1.12.0", - "label": "ubuntu-latest" + "docker_image": "ethereum/client-go:v1.12.0" }, { "product": "vrfv2plus", "name": "TestVRFv2Plus/Link Billing", "eth_client": "geth", - "docker_image": "ethereum/client-go:v1.11.0", - "label": "ubuntu-latest" + "docker_image": "ethereum/client-go:v1.11.0" }, { "product": "vrfv2plus", "name": "TestVRFv2Plus/Link Billing", "eth_client": "geth", - "docker_image": "ethereum/client-go:v1.10.0", - "label": "ubuntu-latest" + "docker_image": "ethereum/client-go:v1.10.0" }, { "product": "automation", "name": "TestSetUpkeepTriggerConfig", "eth_client": "geth", - "docker_image": "ethereum/client-go:latest_stable", - "label": "ubuntu-latest" + "docker_image": "ethereum/client-go:latest_stable" }, { "product": "automation", "name": "TestSetUpkeepTriggerConfig", "eth_client": "geth", - "docker_image": "ethereum/client-go:v1.13.0", - "label": "ubuntu-latest" + "docker_image": "ethereum/client-go:v1.13.0" }, { "product": "automation", "name": "TestSetUpkeepTriggerConfig", "eth_client": "geth", - "docker_image": "ethereum/client-go:v1.12.0", - "label": "ubuntu-latest" + "docker_image": "ethereum/client-go:v1.12.0" }, { "product": "automation", "name": "TestSetUpkeepTriggerConfig", "eth_client": "geth", - "docker_image": "ethereum/client-go:v1.11.0", - "label": "ubuntu-latest" + "docker_image": "ethereum/client-go:v1.11.0" }, { "product": "automation", "name": "TestSetUpkeepTriggerConfig", "eth_client": "geth", - "docker_image": "ethereum/client-go:v1.10.0", - "label": "ubuntu-latest" + "docker_image": "ethereum/client-go:v1.10.0" } ] } \ No newline at end of file diff --git a/integration-tests/smoke/keeper_test.go_test_list.json b/integration-tests/smoke/keeper_test.go_test_list.json index b2f4aa00659..b9ccaa0c008 100644 --- a/integration-tests/smoke/keeper_test.go_test_list.json +++ b/integration-tests/smoke/keeper_test.go_test_list.json @@ -2,42 +2,34 @@ "tests": [ { "name": "TestKeeperBasicSmoke", - "label": "ubuntu-latest", "nodes": 3 }, { "name": "TestKeeperBlockCountPerTurn", - "label": "ubuntu-latest", "nodes": 3 }, { "name": "TestKeeperSimulation", - "label": "ubuntu-latest", "nodes": 2 }, { "name": "TestKeeperCheckPerformGasLimit", - "label": "ubuntu-latest", "nodes": 3 }, { "name": "TestKeeperRegisterUpkeep", - "label": "ubuntu-latest", "nodes": 3 }, { "name": "TestKeeperAddFunds", - "label": "ubuntu-latest", "nodes": 3 }, { "name": "TestKeeperRemove", - "label": "ubuntu-latest", "nodes": 3 }, { "name": "TestKeeperPauseRegistry", - "label": "ubuntu-latest", "nodes": 2 }, { @@ -45,7 +37,6 @@ }, { "name": "TestKeeperNodeDown", - "label": "ubuntu-latest", "nodes": 3 }, { diff --git a/integration-tests/smoke/log_poller_test.go_test_list.json b/integration-tests/smoke/log_poller_test.go_test_list.json index 2159654e283..96939c5133b 100644 --- a/integration-tests/smoke/log_poller_test.go_test_list.json +++ b/integration-tests/smoke/log_poller_test.go_test_list.json @@ -1,44 +1,34 @@ { "tests": [ { - "name": "TestLogPollerFewFiltersFixedDepth", - "label": "ubuntu-latest" + "name": "TestLogPollerFewFiltersFixedDepth" }, { - "name": "TestLogPollerFewFiltersFinalityTag", - "label": "ubuntu-latest" + "name": "TestLogPollerFewFiltersFinalityTag" }, { - "name": "TestLogPollerWithChaosFixedDepth", - "label": "ubuntu-latest" + "name": "TestLogPollerWithChaosFixedDepth" }, { - "name": "TestLogPollerWithChaosFinalityTag", - "label": "ubuntu-latest" + "name": "TestLogPollerWithChaosFinalityTag" }, { - "name": "TestLogPollerWithChaosPostgresFinalityTag", - "label": "ubuntu-latest" + "name": "TestLogPollerWithChaosPostgresFinalityTag" }, { - "name": "TestLogPollerWithChaosPostgresFixedDepth", - "label": "ubuntu-latest" + "name": "TestLogPollerWithChaosPostgresFixedDepth" }, { - "name": "TestLogPollerReplayFixedDepth", - "label": "ubuntu-latest" + "name": "TestLogPollerReplayFixedDepth" }, { - "name": "TestLogPollerReplayFinalityTag", - "label": "ubuntu-latest" + "name": "TestLogPollerReplayFinalityTag" }, { - "name": "TestLogPollerManyFiltersFixedDepth", - "label": "ubuntu-latest" + "name": "TestLogPollerManyFiltersFixedDepth" }, { - "name": "TestLogPollerManyFiltersFinalityTag", - "label": "ubuntu-latest" + "name": "TestLogPollerManyFiltersFinalityTag" } ] } \ No newline at end of file From 99443c53bfe5ec4213e9a8b05ace83c94351e896 Mon Sep 17 00:00:00 2001 From: Jordan Krage Date: Wed, 17 Apr 2024 10:42:51 -0500 Subject: [PATCH 08/19] go.mod: combine indirect blocks (#12860) --- go.mod | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/go.mod b/go.mod index 5a872c77e76..69b389230ac 100644 --- a/go.mod +++ b/go.mod @@ -113,14 +113,6 @@ require ( gopkg.in/natefinch/lumberjack.v2 v2.2.1 ) -require ( - github.com/bahlo/generic-list-go v0.2.0 // indirect - github.com/buger/jsonparser v1.1.1 // indirect - github.com/mailru/easyjson v0.7.7 // indirect - github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect -) - require ( contrib.go.opencensus.io/exporter/stackdriver v0.13.5 // indirect cosmossdk.io/api v0.3.1 // indirect @@ -141,11 +133,13 @@ require ( github.com/VictoriaMetrics/fastcache v1.12.1 // indirect github.com/armon/go-metrics v0.4.1 // indirect github.com/aybabtme/rgbterm v0.0.0-20170906152045-cc83f3b3ce59 // indirect + github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/benbjohnson/clock v1.3.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/bgentry/speakeasy v0.1.1-0.20220910012023-760eaf8b6816 // indirect github.com/bits-and-blooms/bitset v1.10.0 // indirect github.com/blendle/zapdriver v1.3.1 // indirect + github.com/buger/jsonparser v1.1.1 // indirect github.com/bytedance/sonic v1.10.1 // indirect github.com/cenkalti/backoff v2.2.1+incompatible // indirect github.com/cenkalti/backoff/v4 v4.2.1 // indirect @@ -260,6 +254,7 @@ require ( github.com/libp2p/go-buffer-pool v0.1.0 // indirect github.com/logrusorgru/aurora v2.0.3+incompatible // indirect github.com/magiconair/properties v1.8.7 // indirect + github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.14 // indirect @@ -313,6 +308,7 @@ require ( github.com/tyler-smith/go-bip39 v1.1.0 // indirect github.com/umbracle/fastrlp v0.0.0-20220527094140-59d5dd30e722 // indirect github.com/valyala/fastjson v1.4.1 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/x448/float16 v0.8.4 // indirect github.com/yusufpapurcu/wmi v1.2.3 // indirect github.com/zondax/hid v0.9.1 // indirect @@ -340,6 +336,7 @@ require ( gopkg.in/guregu/null.v2 v2.1.2 // indirect gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect nhooyr.io/websocket v1.8.7 // indirect pgregory.net/rapid v0.5.5 // indirect rsc.io/tmplfunc v0.0.3 // indirect From 0ec92765ccd419973f4eab5b0cc38df212f4ad21 Mon Sep 17 00:00:00 2001 From: Jordan Krage Date: Wed, 17 Apr 2024 10:58:58 -0500 Subject: [PATCH 09/19] switch more EVM components to use sqlutil.DataStore (#12856) --- .changeset/soft-hotels-decide.md | 5 + common/txmgr/broadcaster.go | 2 +- common/txmgr/confirmer.go | 2 +- common/txmgr/txmgr.go | 2 +- core/bridges/orm_test.go | 5 +- .../evm/forwarders/forwarder_manager.go | 4 +- core/chains/evm/forwarders/orm.go | 46 +-- core/chains/evm/headtracker/orm.go | 16 +- core/chains/evm/log/broadcaster.go | 56 ++-- core/chains/evm/log/helpers_test.go | 2 +- core/chains/evm/log/mocks/broadcaster.go | 30 +- core/chains/evm/log/orm.go | 74 ++--- core/chains/evm/log/orm_test.go | 66 +--- core/chains/evm/log/registrations.go | 7 +- core/chains/evm/log/registrations_test.go | 5 +- core/chains/evm/txmgr/broadcaster_test.go | 4 +- core/chains/evm/txmgr/confirmer_test.go | 11 +- core/internal/cltest/cltest.go | 2 +- core/internal/cltest/factories.go | 2 +- core/internal/cltest/job_factories.go | 6 +- core/internal/features/features_test.go | 2 +- core/services/blockhashstore/delegate.go | 3 +- core/services/blockheaderfeeder/delegate.go | 3 +- core/services/chainlink/application.go | 5 +- core/services/cron/cron_test.go | 2 +- core/services/cron/delegate.go | 9 +- core/services/directrequest/delegate.go | 16 +- core/services/directrequest/delegate_test.go | 85 +++--- core/services/feeds/orm_test.go | 2 +- core/services/feeds/service.go | 2 +- core/services/fluxmonitorv2/delegate.go | 14 +- core/services/fluxmonitorv2/flux_monitor.go | 82 ++--- .../fluxmonitorv2/flux_monitor_test.go | 185 ++++++------ core/services/fluxmonitorv2/mocks/orm.go | 103 ++++--- core/services/fluxmonitorv2/orm.go | 60 ++-- core/services/fluxmonitorv2/orm_test.go | 38 +-- core/services/gateway/delegate.go | 8 +- core/services/job/job_orm_test.go | 43 +-- .../job/job_pipeline_orm_integration_test.go | 7 +- core/services/job/kv_orm_test.go | 2 +- core/services/job/orm.go | 2 +- core/services/job/runner_integration_test.go | 6 +- core/services/job/spawner.go | 14 +- core/services/job/spawner_test.go | 10 +- core/services/keeper/delegate.go | 9 +- core/services/keeper/integration_test.go | 12 +- .../keeper/registry1_1_synchronizer_test.go | 25 +- .../keeper/registry1_2_synchronizer_test.go | 40 ++- .../keeper/registry1_3_synchronizer_test.go | 54 ++-- .../registry_synchronizer_log_listener.go | 3 +- .../registry_synchronizer_process_logs.go | 2 +- core/services/ocr/contract_tracker.go | 40 ++- core/services/ocr/contract_tracker_test.go | 35 ++- core/services/ocr/database.go | 41 +-- core/services/ocr/database_test.go | 10 +- core/services/ocr/delegate.go | 12 +- core/services/ocr/helpers_internal_test.go | 3 +- .../ocr/mocks/ocr_contract_tracker_db.go | 32 +- core/services/ocr2/delegate.go | 6 +- .../generic/pipeline_runner_adapter_test.go | 2 +- core/services/ocrbootstrap/delegate.go | 3 +- core/services/ocrcommon/run_saver.go | 22 +- core/services/ocrcommon/run_saver_test.go | 2 +- core/services/pipeline/helpers_test.go | 3 + core/services/pipeline/mocks/orm.go | 283 +++++++++--------- core/services/pipeline/mocks/runner.go | 56 ++-- core/services/pipeline/orm.go | 278 +++++++++-------- core/services/pipeline/orm_test.go | 69 +++-- core/services/pipeline/runner.go | 50 ++-- core/services/pipeline/runner_test.go | 35 ++- core/services/pipeline/task.bridge_test.go | 56 ++-- core/services/pipeline/task.http_test.go | 5 +- .../relay/evm/mocks/request_round_db.go | 20 +- core/services/relay/evm/request_round_db.go | 11 +- .../relay/evm/request_round_db_test.go | 8 +- .../relay/evm/request_round_tracker.go | 19 +- .../relay/evm/request_round_tracker_test.go | 37 +-- core/services/streams/delegate.go | 9 +- core/services/streams/stream_test.go | 4 +- core/services/vrf/delegate.go | 24 +- core/services/vrf/delegate_test.go | 34 ++- core/services/vrf/v1/integration_test.go | 6 +- core/services/vrf/v1/listener_v1.go | 19 +- .../vrf/v2/integration_helpers_test.go | 109 ++++--- .../vrf/v2/integration_v2_plus_test.go | 3 +- core/services/vrf/v2/integration_v2_test.go | 42 +-- core/services/vrf/v2/listener_v2.go | 8 +- .../vrf/v2/listener_v2_log_processor.go | 92 +++--- core/services/vrf/v2/listener_v2_types.go | 6 +- core/services/vrf/v2/reverted_txns.go | 29 +- core/services/webhook/delegate.go | 3 +- core/services/workflows/delegate.go | 3 +- core/store/migrate/migrate_test.go | 11 +- core/web/pipeline_runs_controller.go | 6 +- core/web/resolver/job_run_test.go | 4 +- core/web/resolver/mutation.go | 2 +- 96 files changed, 1363 insertions(+), 1384 deletions(-) create mode 100644 .changeset/soft-hotels-decide.md diff --git a/.changeset/soft-hotels-decide.md b/.changeset/soft-hotels-decide.md new file mode 100644 index 00000000000..75b4cadd4e5 --- /dev/null +++ b/.changeset/soft-hotels-decide.md @@ -0,0 +1,5 @@ +--- +"chainlink": patch +--- + +switch more EVM components to use sqlutil.DataStore #internal diff --git a/common/txmgr/broadcaster.go b/common/txmgr/broadcaster.go index a13673bf91b..1651f6417bf 100644 --- a/common/txmgr/broadcaster.go +++ b/common/txmgr/broadcaster.go @@ -689,7 +689,7 @@ func (eb *Broadcaster[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) save // is relatively benign and probably nobody will ever run into it in // practice, but something to be aware of. if etx.PipelineTaskRunID.Valid && eb.resumeCallback != nil && etx.SignalCallback { - err := eb.resumeCallback(etx.PipelineTaskRunID.UUID, nil, fmt.Errorf("fatal error while sending transaction: %s", etx.Error.String)) + err := eb.resumeCallback(ctx, etx.PipelineTaskRunID.UUID, nil, fmt.Errorf("fatal error while sending transaction: %s", etx.Error.String)) if errors.Is(err, sql.ErrNoRows) { lgr.Debugw("callback missing or already resumed", "etxID", etx.ID) } else if err != nil { diff --git a/common/txmgr/confirmer.go b/common/txmgr/confirmer.go index 53e1c3c4206..d61f9a3dddd 100644 --- a/common/txmgr/confirmer.go +++ b/common/txmgr/confirmer.go @@ -1120,7 +1120,7 @@ func (ec *Confirmer[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) Res } ec.lggr.Debugw("Callback: resuming tx with receipt", "output", output, "taskErr", taskErr, "pipelineTaskRunID", data.ID) - if err := ec.resumeCallback(data.ID, output, taskErr); err != nil { + if err := ec.resumeCallback(ctx, data.ID, output, taskErr); err != nil { return fmt.Errorf("failed to resume suspended pipeline run: %w", err) } // Mark tx as having completed callback diff --git a/common/txmgr/txmgr.go b/common/txmgr/txmgr.go index d183a8c3ade..b996b76f1a5 100644 --- a/common/txmgr/txmgr.go +++ b/common/txmgr/txmgr.go @@ -27,7 +27,7 @@ import ( // https://www.notion.so/chainlink/Txm-Architecture-Overview-9dc62450cd7a443ba9e7dceffa1a8d6b // ResumeCallback is assumed to be idempotent -type ResumeCallback func(id uuid.UUID, result interface{}, err error) error +type ResumeCallback func(ctx context.Context, id uuid.UUID, result interface{}, err error) error // TxManager is the main component of the transaction manager. // It is also the interface to external callers. diff --git a/core/bridges/orm_test.go b/core/bridges/orm_test.go index 204dc5fe115..85e8b9ecdef 100644 --- a/core/bridges/orm_test.go +++ b/core/bridges/orm_test.go @@ -17,7 +17,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/configtest" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/pgtest" "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" "github.com/smartcontractkit/chainlink/v2/core/store/models" ) @@ -144,8 +143,8 @@ func TestORM_TestCachedResponse(t *testing.T) { db := pgtest.NewSqlxDB(t) orm := bridges.NewORM(db) - trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute), pg.WithParentCtx(testutils.Context(t))) + trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) + specID, err := trORM.CreateSpec(ctx, nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) _, err = orm.GetCachedResponse(ctx, "dot", specID, 1*time.Second) diff --git a/core/chains/evm/forwarders/forwarder_manager.go b/core/chains/evm/forwarders/forwarder_manager.go index f0786c091c4..7a7a274127f 100644 --- a/core/chains/evm/forwarders/forwarder_manager.go +++ b/core/chains/evm/forwarders/forwarder_manager.go @@ -54,13 +54,13 @@ type FwdMgr struct { wg sync.WaitGroup } -func NewFwdMgr(db sqlutil.DataSource, client evmclient.Client, logpoller evmlogpoller.LogPoller, l logger.Logger, cfg Config) *FwdMgr { +func NewFwdMgr(ds sqlutil.DataSource, client evmclient.Client, logpoller evmlogpoller.LogPoller, l logger.Logger, cfg Config) *FwdMgr { lggr := logger.Sugared(logger.Named(l, "EVMForwarderManager")) fwdMgr := FwdMgr{ logger: lggr, cfg: cfg, evmClient: client, - ORM: NewORM(db), + ORM: NewORM(ds), logpoller: logpoller, sendersCache: make(map[common.Address][]common.Address), } diff --git a/core/chains/evm/forwarders/orm.go b/core/chains/evm/forwarders/orm.go index cf498518d6d..8076cba4831 100644 --- a/core/chains/evm/forwarders/orm.go +++ b/core/chains/evm/forwarders/orm.go @@ -23,50 +23,50 @@ type ORM interface { FindForwardersInListByChain(ctx context.Context, evmChainId big.Big, addrs []common.Address) ([]Forwarder, error) } -type DbORM struct { - db sqlutil.DataSource +type DSORM struct { + ds sqlutil.DataSource } -var _ ORM = &DbORM{} +var _ ORM = &DSORM{} -func NewORM(db sqlutil.DataSource) *DbORM { - return &DbORM{db: db} +func NewORM(ds sqlutil.DataSource) *DSORM { + return &DSORM{ds: ds} } -func (o *DbORM) Transaction(ctx context.Context, fn func(*DbORM) error) (err error) { - return sqlutil.Transact(ctx, o.new, o.db, nil, fn) +func (o *DSORM) Transact(ctx context.Context, fn func(*DSORM) error) (err error) { + return sqlutil.Transact(ctx, o.new, o.ds, nil, fn) } // new returns a NewORM like o, but backed by q. -func (o *DbORM) new(q sqlutil.DataSource) *DbORM { return NewORM(q) } +func (o *DSORM) new(q sqlutil.DataSource) *DSORM { return NewORM(q) } // CreateForwarder creates the Forwarder address associated with the current EVM chain id. -func (o *DbORM) CreateForwarder(ctx context.Context, addr common.Address, evmChainId big.Big) (fwd Forwarder, err error) { +func (o *DSORM) CreateForwarder(ctx context.Context, addr common.Address, evmChainId big.Big) (fwd Forwarder, err error) { sql := `INSERT INTO evm.forwarders (address, evm_chain_id, created_at, updated_at) VALUES ($1, $2, now(), now()) RETURNING *` - err = o.db.GetContext(ctx, &fwd, sql, addr, evmChainId) + err = o.ds.GetContext(ctx, &fwd, sql, addr, evmChainId) return fwd, err } // DeleteForwarder removes a forwarder address. // If cleanup is non-nil, it can be used to perform any chain- or contract-specific cleanup that need to happen atomically // on forwarder deletion. If cleanup returns an error, forwarder deletion will be aborted. -func (o *DbORM) DeleteForwarder(ctx context.Context, id int64, cleanup func(tx sqlutil.DataSource, evmChainID int64, addr common.Address) error) (err error) { - return o.Transaction(ctx, func(orm *DbORM) error { +func (o *DSORM) DeleteForwarder(ctx context.Context, id int64, cleanup func(tx sqlutil.DataSource, evmChainID int64, addr common.Address) error) (err error) { + return o.Transact(ctx, func(orm *DSORM) error { var dest struct { EvmChainId int64 Address common.Address } - err := orm.db.GetContext(ctx, &dest, `SELECT evm_chain_id, address FROM evm.forwarders WHERE id = $1`, id) + err := orm.ds.GetContext(ctx, &dest, `SELECT evm_chain_id, address FROM evm.forwarders WHERE id = $1`, id) if err != nil { return err } if cleanup != nil { - if err = cleanup(orm.db, dest.EvmChainId, dest.Address); err != nil { + if err = cleanup(orm.ds, dest.EvmChainId, dest.Address); err != nil { return err } } - result, err := orm.db.ExecContext(ctx, `DELETE FROM evm.forwarders WHERE id = $1`, id) + result, err := orm.ds.ExecContext(ctx, `DELETE FROM evm.forwarders WHERE id = $1`, id) // If the forwarder wasn't found, we still want to delete the filter. // In that case, the transaction must return nil, even though DeleteForwarder // will return sql.ErrNoRows @@ -82,27 +82,27 @@ func (o *DbORM) DeleteForwarder(ctx context.Context, id int64, cleanup func(tx s } // FindForwarders returns all forwarder addresses from offset up until limit. -func (o *DbORM) FindForwarders(ctx context.Context, offset, limit int) (fwds []Forwarder, count int, err error) { +func (o *DSORM) FindForwarders(ctx context.Context, offset, limit int) (fwds []Forwarder, count int, err error) { sql := `SELECT count(*) FROM evm.forwarders` - if err = o.db.GetContext(ctx, &count, sql); err != nil { + if err = o.ds.GetContext(ctx, &count, sql); err != nil { return } sql = `SELECT * FROM evm.forwarders ORDER BY created_at DESC, id DESC LIMIT $1 OFFSET $2` - if err = o.db.SelectContext(ctx, &fwds, sql, limit, offset); err != nil { + if err = o.ds.SelectContext(ctx, &fwds, sql, limit, offset); err != nil { return } return } // FindForwardersByChain returns all forwarder addresses for a chain. -func (o *DbORM) FindForwardersByChain(ctx context.Context, evmChainId big.Big) (fwds []Forwarder, err error) { +func (o *DSORM) FindForwardersByChain(ctx context.Context, evmChainId big.Big) (fwds []Forwarder, err error) { sql := `SELECT * FROM evm.forwarders where evm_chain_id = $1 ORDER BY created_at DESC, id DESC` - err = o.db.SelectContext(ctx, &fwds, sql, evmChainId) + err = o.ds.SelectContext(ctx, &fwds, sql, evmChainId) return } -func (o *DbORM) FindForwardersInListByChain(ctx context.Context, evmChainId big.Big, addrs []common.Address) ([]Forwarder, error) { +func (o *DSORM) FindForwardersInListByChain(ctx context.Context, evmChainId big.Big, addrs []common.Address) ([]Forwarder, error) { var fwdrs []Forwarder arg := map[string]interface{}{ @@ -127,8 +127,8 @@ func (o *DbORM) FindForwardersInListByChain(ctx context.Context, evmChainId big. return nil, pkgerrors.Wrap(err, "Failed to run sqlx.IN on query") } - query = o.db.Rebind(query) - err = o.db.SelectContext(ctx, &fwdrs, query, args...) + query = o.ds.Rebind(query) + err = o.ds.SelectContext(ctx, &fwdrs, query, args...) if err != nil { return nil, pkgerrors.Wrap(err, "Failed to execute query") diff --git a/core/chains/evm/headtracker/orm.go b/core/chains/evm/headtracker/orm.go index 8912bafecdf..9d569ade08d 100644 --- a/core/chains/evm/headtracker/orm.go +++ b/core/chains/evm/headtracker/orm.go @@ -31,14 +31,14 @@ var _ ORM = &DbORM{} type DbORM struct { chainID ubig.Big - db sqlutil.DataSource + ds sqlutil.DataSource } // NewORM creates an ORM scoped to chainID. -func NewORM(chainID big.Int, db sqlutil.DataSource) *DbORM { +func NewORM(chainID big.Int, ds sqlutil.DataSource) *DbORM { return &DbORM{ chainID: ubig.Big(chainID), - db: db, + ds: ds, } } @@ -48,19 +48,19 @@ func (orm *DbORM) IdempotentInsertHead(ctx context.Context, head *evmtypes.Head) INSERT INTO evm.heads (hash, number, parent_hash, created_at, timestamp, l1_block_number, evm_chain_id, base_fee_per_gas) VALUES ( $1, $2, $3, $4, $5, $6, $7, $8) ON CONFLICT (evm_chain_id, hash) DO NOTHING` - _, err := orm.db.ExecContext(ctx, query, head.Hash, head.Number, head.ParentHash, head.CreatedAt, head.Timestamp, head.L1BlockNumber, orm.chainID, head.BaseFeePerGas) + _, err := orm.ds.ExecContext(ctx, query, head.Hash, head.Number, head.ParentHash, head.CreatedAt, head.Timestamp, head.L1BlockNumber, orm.chainID, head.BaseFeePerGas) return pkgerrors.Wrap(err, "IdempotentInsertHead failed to insert head") } func (orm *DbORM) TrimOldHeads(ctx context.Context, minBlockNumber int64) (err error) { query := `DELETE FROM evm.heads WHERE evm_chain_id = $1 AND number < $2` - _, err = orm.db.ExecContext(ctx, query, orm.chainID, minBlockNumber) + _, err = orm.ds.ExecContext(ctx, query, orm.chainID, minBlockNumber) return err } func (orm *DbORM) LatestHead(ctx context.Context) (head *evmtypes.Head, err error) { head = new(evmtypes.Head) - err = orm.db.GetContext(ctx, head, `SELECT * FROM evm.heads WHERE evm_chain_id = $1 ORDER BY number DESC, created_at DESC, id DESC LIMIT 1`, orm.chainID) + err = orm.ds.GetContext(ctx, head, `SELECT * FROM evm.heads WHERE evm_chain_id = $1 ORDER BY number DESC, created_at DESC, id DESC LIMIT 1`, orm.chainID) if pkgerrors.Is(err, sql.ErrNoRows) { return nil, nil } @@ -69,14 +69,14 @@ func (orm *DbORM) LatestHead(ctx context.Context) (head *evmtypes.Head, err erro } func (orm *DbORM) LatestHeads(ctx context.Context, minBlockNumer int64) (heads []*evmtypes.Head, err error) { - err = orm.db.SelectContext(ctx, &heads, `SELECT * FROM evm.heads WHERE evm_chain_id = $1 AND number >= $2 ORDER BY number DESC, created_at DESC, id DESC`, orm.chainID, minBlockNumer) + err = orm.ds.SelectContext(ctx, &heads, `SELECT * FROM evm.heads WHERE evm_chain_id = $1 AND number >= $2 ORDER BY number DESC, created_at DESC, id DESC`, orm.chainID, minBlockNumer) err = pkgerrors.Wrap(err, "LatestHeads failed") return } func (orm *DbORM) HeadByHash(ctx context.Context, hash common.Hash) (head *evmtypes.Head, err error) { head = new(evmtypes.Head) - err = orm.db.GetContext(ctx, head, `SELECT * FROM evm.heads WHERE evm_chain_id = $1 AND hash = $2`, orm.chainID, hash) + err = orm.ds.GetContext(ctx, head, `SELECT * FROM evm.heads WHERE evm_chain_id = $1 AND hash = $2`, orm.chainID, hash) if pkgerrors.Is(err, sql.ErrNoRows) { return nil, nil } diff --git a/core/chains/evm/log/broadcaster.go b/core/chains/evm/log/broadcaster.go index a96474c0f78..148c36148c2 100644 --- a/core/chains/evm/log/broadcaster.go +++ b/core/chains/evm/log/broadcaster.go @@ -9,14 +9,13 @@ import ( "sync/atomic" "time" - "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" - "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" pkgerrors "github.com/pkg/errors" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/utils" "github.com/smartcontractkit/chainlink-common/pkg/utils/mailbox" @@ -60,12 +59,10 @@ type ( Register(listener Listener, opts ListenerOpts) (unsubscribe func()) WasAlreadyConsumed(ctx context.Context, lb Broadcast) (bool, error) - MarkConsumed(ctx context.Context, lb Broadcast) error - - // MarkManyConsumed marks all the provided log broadcasts as consumed. - MarkManyConsumed(ctx context.Context, lbs []Broadcast) error + // ds is optional + MarkConsumed(ctx context.Context, ds sqlutil.DataSource, lb Broadcast) error - // NOTE: WasAlreadyConsumed, MarkConsumed and MarkManyConsumed MUST be used within a single goroutine in order for WasAlreadyConsumed to be accurate + // NOTE: WasAlreadyConsumed, and MarkConsumed MUST be used within a single goroutine in order for WasAlreadyConsumed to be accurate } BroadcasterInTest interface { @@ -422,12 +419,15 @@ func (b *broadcaster) eventLoop(chRawLogs <-chan types.Log, chErr <-chan error) debounceResubscribe := time.NewTicker(1 * time.Second) defer debounceResubscribe.Stop() + ctx, cancel := b.chStop.NewCtx() + defer cancel() + b.logger.Debug("Starting the event loop") for { // Replay requests take priority. select { case req := <-b.replayChannel: - b.onReplayRequest(req) + b.onReplayRequest(ctx, req) return true, nil default: } @@ -456,7 +456,7 @@ func (b *broadcaster) eventLoop(chRawLogs <-chan types.Log, chErr <-chan error) needsResubscribe = b.onChangeSubscriberStatus() || needsResubscribe case req := <-b.replayChannel: - b.onReplayRequest(req) + b.onReplayRequest(ctx, req) return true, nil case <-debounceResubscribe.C: @@ -480,7 +480,7 @@ func (b *broadcaster) eventLoop(chRawLogs <-chan types.Log, chErr <-chan error) } // onReplayRequest clears the pool and sets the block backfill number. -func (b *broadcaster) onReplayRequest(replayReq replayRequest) { +func (b *broadcaster) onReplayRequest(ctx context.Context, replayReq replayRequest) { // notify subscribers that we are about to replay. for subscriber := range b.registrations.registeredSubs { if subscriber.opts.ReplayStartedCallback != nil { @@ -495,11 +495,11 @@ func (b *broadcaster) onReplayRequest(replayReq replayRequest) { b.backfillBlockNumber.Int64 = replayReq.fromBlock b.backfillBlockNumber.Valid = true if replayReq.forceBroadcast { - ctx, cancel := b.chStop.CtxCancel(context.WithTimeout(context.Background(), time.Minute)) - ctx = sqlutil.WithoutDefaultTimeout(ctx) - defer cancel() // Use a longer timeout in the event that a very large amount of logs need to be marked - // as consumed. + // as unconsumed. + var cancel func() + ctx, cancel = context.WithTimeout(sqlutil.WithoutDefaultTimeout(ctx), time.Minute) + defer cancel() err := b.orm.MarkBroadcastsUnconsumed(ctx, replayReq.fromBlock) if err != nil { b.logger.Errorw("Error marking broadcasts as unconsumed", @@ -694,25 +694,12 @@ func (b *broadcaster) WasAlreadyConsumed(ctx context.Context, lb Broadcast) (boo } // MarkConsumed marks the log as having been successfully consumed by the subscriber -func (b *broadcaster) MarkConsumed(ctx context.Context, lb Broadcast) error { - return b.orm.MarkBroadcastConsumed(ctx, lb.RawLog().BlockHash, lb.RawLog().BlockNumber, lb.RawLog().Index, lb.JobID()) -} - -// MarkManyConsumed marks the logs as having been successfully consumed by the subscriber -func (b *broadcaster) MarkManyConsumed(ctx context.Context, lbs []Broadcast) (err error) { - var ( - blockHashes = make([]common.Hash, len(lbs)) - blockNumbers = make([]uint64, len(lbs)) - logIndexes = make([]uint, len(lbs)) - jobIDs = make([]int32, len(lbs)) - ) - for i := range lbs { - blockHashes[i] = lbs[i].RawLog().BlockHash - blockNumbers[i] = lbs[i].RawLog().BlockNumber - logIndexes[i] = lbs[i].RawLog().Index - jobIDs[i] = lbs[i].JobID() +func (b *broadcaster) MarkConsumed(ctx context.Context, ds sqlutil.DataSource, lb Broadcast) error { + orm := b.orm + if ds != nil { + orm = orm.WithDataSource(ds) } - return b.orm.MarkBroadcastsConsumed(ctx, blockHashes, blockNumbers, logIndexes, jobIDs) + return orm.MarkBroadcastConsumed(ctx, lb.RawLog().BlockHash, lb.RawLog().BlockNumber, lb.RawLog().Index, lb.JobID()) } // test only @@ -779,10 +766,7 @@ func (n *NullBroadcaster) TrackedAddressesCount() uint32 { func (n *NullBroadcaster) WasAlreadyConsumed(ctx context.Context, lb Broadcast) (bool, error) { return false, pkgerrors.New(n.ErrMsg) } -func (n *NullBroadcaster) MarkConsumed(ctx context.Context, lb Broadcast) error { - return pkgerrors.New(n.ErrMsg) -} -func (n *NullBroadcaster) MarkManyConsumed(ctx context.Context, lbs []Broadcast) error { +func (n *NullBroadcaster) MarkConsumed(ctx context.Context, ds sqlutil.DataSource, lb Broadcast) error { return pkgerrors.New(n.ErrMsg) } diff --git a/core/chains/evm/log/helpers_test.go b/core/chains/evm/log/helpers_test.go index 18f396fab9d..85c2fe783bb 100644 --- a/core/chains/evm/log/helpers_test.go +++ b/core/chains/evm/log/helpers_test.go @@ -281,7 +281,7 @@ func (listener *simpleLogListener) SkipMarkingConsumed(skip bool) { listener.skipMarkingConsumed.Store(skip) } -func (listener *simpleLogListener) HandleLog(lb log.Broadcast) { +func (listener *simpleLogListener) HandleLog(ctx context.Context, lb log.Broadcast) { listener.received.Lock() defer listener.received.Unlock() listener.lggr.Tracef("Listener %v HandleLog for block %v %v received at %v %v", listener.name, lb.RawLog().BlockNumber, lb.RawLog().BlockHash, lb.LatestBlockNumber(), lb.LatestBlockHash()) diff --git a/core/chains/evm/log/mocks/broadcaster.go b/core/chains/evm/log/mocks/broadcaster.go index 26fe1a35101..e5164b56611 100644 --- a/core/chains/evm/log/mocks/broadcaster.go +++ b/core/chains/evm/log/mocks/broadcaster.go @@ -8,6 +8,8 @@ import ( log "github.com/smartcontractkit/chainlink/v2/core/chains/evm/log" mock "github.com/stretchr/testify/mock" + sqlutil "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" + types "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" ) @@ -102,35 +104,17 @@ func (_m *Broadcaster) IsConnected() bool { return r0 } -// MarkConsumed provides a mock function with given fields: ctx, lb -func (_m *Broadcaster) MarkConsumed(ctx context.Context, lb log.Broadcast) error { - ret := _m.Called(ctx, lb) +// MarkConsumed provides a mock function with given fields: ctx, ds, lb +func (_m *Broadcaster) MarkConsumed(ctx context.Context, ds sqlutil.DataSource, lb log.Broadcast) error { + ret := _m.Called(ctx, ds, lb) if len(ret) == 0 { panic("no return value specified for MarkConsumed") } var r0 error - if rf, ok := ret.Get(0).(func(context.Context, log.Broadcast) error); ok { - r0 = rf(ctx, lb) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// MarkManyConsumed provides a mock function with given fields: ctx, lbs -func (_m *Broadcaster) MarkManyConsumed(ctx context.Context, lbs []log.Broadcast) error { - ret := _m.Called(ctx, lbs) - - if len(ret) == 0 { - panic("no return value specified for MarkManyConsumed") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, []log.Broadcast) error); ok { - r0 = rf(ctx, lbs) + if rf, ok := ret.Get(0).(func(context.Context, sqlutil.DataSource, log.Broadcast) error); ok { + r0 = rf(ctx, ds, lb) } else { r0 = ret.Error(0) } diff --git a/core/chains/evm/log/orm.go b/core/chains/evm/log/orm.go index 71c9675d6fd..6e94d3bf8a8 100644 --- a/core/chains/evm/log/orm.go +++ b/core/chains/evm/log/orm.go @@ -3,16 +3,13 @@ package log import ( "context" "database/sql" - "fmt" "math/big" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" - "github.com/jmoiron/sqlx" pkgerrors "github.com/pkg/errors" "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" - "github.com/smartcontractkit/chainlink-common/pkg/utils" ubig "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big" ) @@ -31,8 +28,6 @@ type ORM interface { WasBroadcastConsumed(ctx context.Context, blockHash common.Hash, logIndex uint, jobID int32) (bool, error) // MarkBroadcastConsumed marks the log broadcast as consumed by jobID. MarkBroadcastConsumed(ctx context.Context, blockHash common.Hash, blockNumber uint64, logIndex uint, jobID int32) error - // MarkBroadcastsConsumed marks the log broadcasts as consumed by jobID. - MarkBroadcastsConsumed(ctx context.Context, blockHashes []common.Hash, blockNumbers []uint64, logIndexes []uint, jobIDs []int32) error // MarkBroadcastsUnconsumed marks all log broadcasts from all jobs on or after fromBlock as // unconsumed. MarkBroadcastsUnconsumed(ctx context.Context, fromBlock int64) error @@ -45,20 +40,23 @@ type ORM interface { // Reinitialize cleans up the database by removing any unconsumed broadcasts, then updating (if necessary) and // returning the pending minimum block number. Reinitialize(ctx context.Context) (blockNumber *int64, err error) + + WithDataSource(sqlutil.DataSource) ORM } type orm struct { - db sqlutil.DataSource + ds sqlutil.DataSource evmChainID ubig.Big } var _ ORM = (*orm)(nil) -func NewORM(db sqlutil.DataSource, evmChainID big.Int) *orm { - return &orm{ - db: db, - evmChainID: *ubig.New(&evmChainID), - } +func NewORM(ds sqlutil.DataSource, evmChainID big.Int) *orm { + return &orm{ds, *ubig.New(&evmChainID)} +} + +func (o *orm) WithDataSource(ds sqlutil.DataSource) ORM { + return &orm{ds, o.evmChainID} } func (o *orm) WasBroadcastConsumed(ctx context.Context, blockHash common.Hash, logIndex uint, jobID int32) (consumed bool, err error) { @@ -75,7 +73,7 @@ func (o *orm) WasBroadcastConsumed(ctx context.Context, blockHash common.Hash, l jobID, o.evmChainID, } - err = o.db.GetContext(ctx, &consumed, query, args...) + err = o.ds.GetContext(ctx, &consumed, query, args...) if pkgerrors.Is(err, sql.ErrNoRows) { return false, nil } @@ -90,7 +88,7 @@ func (o *orm) FindBroadcasts(ctx context.Context, fromBlockNum int64, toBlockNum AND block_number <= $2 AND evm_chain_id = $3 ` - err := o.db.SelectContext(ctx, &broadcasts, query, fromBlockNum, toBlockNum, o.evmChainID) + err := o.ds.SelectContext(ctx, &broadcasts, query, fromBlockNum, toBlockNum, o.evmChainID) if err != nil { return nil, pkgerrors.Wrap(err, "failed to find log broadcasts") } @@ -98,7 +96,7 @@ func (o *orm) FindBroadcasts(ctx context.Context, fromBlockNum int64, toBlockNum } func (o *orm) CreateBroadcast(ctx context.Context, blockHash common.Hash, blockNumber uint64, logIndex uint, jobID int32) error { - _, err := o.db.ExecContext(ctx, ` + _, err := o.ds.ExecContext(ctx, ` INSERT INTO log_broadcasts (block_hash, block_number, log_index, job_id, created_at, updated_at, consumed, evm_chain_id) VALUES ($1, $2, $3, $4, NOW(), NOW(), false, $5) `, blockHash, blockNumber, logIndex, jobID, o.evmChainID) @@ -106,7 +104,7 @@ func (o *orm) CreateBroadcast(ctx context.Context, blockHash common.Hash, blockN } func (o *orm) MarkBroadcastConsumed(ctx context.Context, blockHash common.Hash, blockNumber uint64, logIndex uint, jobID int32) error { - _, err := o.db.ExecContext(ctx, ` + _, err := o.ds.ExecContext(ctx, ` INSERT INTO log_broadcasts (block_hash, block_number, log_index, job_id, created_at, updated_at, consumed, evm_chain_id) VALUES ($1, $2, $3, $4, NOW(), NOW(), true, $5) ON CONFLICT (job_id, block_hash, log_index, evm_chain_id) DO UPDATE @@ -115,45 +113,9 @@ func (o *orm) MarkBroadcastConsumed(ctx context.Context, blockHash common.Hash, return pkgerrors.Wrap(err, "failed to mark log broadcast as consumed") } -// MarkBroadcastsConsumed marks many broadcasts as consumed. -// The lengths of all the provided slices must be equal, otherwise an error is returned. -func (o *orm) MarkBroadcastsConsumed(ctx context.Context, blockHashes []common.Hash, blockNumbers []uint64, logIndexes []uint, jobIDs []int32) error { - if !utils.AllEqual(len(blockHashes), len(blockNumbers), len(logIndexes), len(jobIDs)) { - return fmt.Errorf("all arg slice lengths must be equal, got: %d %d %d %d", - len(blockHashes), len(blockNumbers), len(logIndexes), len(jobIDs), - ) - } - - type input struct { - BlockHash common.Hash `db:"blockHash"` - BlockNumber uint64 `db:"blockNumber"` - LogIndex uint `db:"logIndex"` - JobID int32 `db:"jobID"` - ChainID ubig.Big `db:"chainID"` - } - inputs := make([]input, len(blockHashes)) - query := ` -INSERT INTO log_broadcasts (block_hash, block_number, log_index, job_id, created_at, updated_at, consumed, evm_chain_id) -VALUES (:blockHash, :blockNumber, :logIndex, :jobID, NOW(), NOW(), true, :chainID) -ON CONFLICT (job_id, block_hash, log_index, evm_chain_id) DO UPDATE -SET consumed = true, updated_at = NOW(); - ` - for i := range blockHashes { - inputs[i] = input{ - BlockHash: blockHashes[i], - BlockNumber: blockNumbers[i], - LogIndex: logIndexes[i], - JobID: jobIDs[i], - ChainID: o.evmChainID, - } - } - _, err := o.db.(*sqlx.DB).NamedExecContext(ctx, query, inputs) - return pkgerrors.Wrap(err, "mark broadcasts consumed") -} - // MarkBroadcastsUnconsumed implements the ORM interface. func (o *orm) MarkBroadcastsUnconsumed(ctx context.Context, fromBlock int64) error { - _, err := o.db.ExecContext(ctx, ` + _, err := o.ds.ExecContext(ctx, ` UPDATE log_broadcasts SET consumed = false WHERE block_number >= $1 @@ -193,7 +155,7 @@ func (o *orm) Reinitialize(ctx context.Context) (*int64, error) { } func (o *orm) SetPendingMinBlock(ctx context.Context, blockNumber *int64) error { - _, err := o.db.ExecContext(ctx, ` + _, err := o.ds.ExecContext(ctx, ` INSERT INTO log_broadcasts_pending (evm_chain_id, block_number, created_at, updated_at) VALUES ($1, $2, NOW(), NOW()) ON CONFLICT (evm_chain_id) DO UPDATE SET block_number = $3, updated_at = NOW() `, o.evmChainID, blockNumber, blockNumber) @@ -202,7 +164,7 @@ func (o *orm) SetPendingMinBlock(ctx context.Context, blockNumber *int64) error func (o *orm) GetPendingMinBlock(ctx context.Context) (*int64, error) { var blockNumber *int64 - err := o.db.GetContext(ctx, &blockNumber, ` + err := o.ds.GetContext(ctx, &blockNumber, ` SELECT block_number FROM log_broadcasts_pending WHERE evm_chain_id = $1 `, o.evmChainID) if pkgerrors.Is(err, sql.ErrNoRows) { @@ -215,7 +177,7 @@ func (o *orm) GetPendingMinBlock(ctx context.Context) (*int64, error) { func (o *orm) getUnconsumedMinBlock(ctx context.Context) (*int64, error) { var blockNumber *int64 - err := o.db.GetContext(ctx, &blockNumber, ` + err := o.ds.GetContext(ctx, &blockNumber, ` SELECT min(block_number) FROM log_broadcasts WHERE evm_chain_id = $1 AND consumed = false @@ -230,7 +192,7 @@ func (o *orm) getUnconsumedMinBlock(ctx context.Context) (*int64, error) { } func (o *orm) removeUnconsumed(ctx context.Context) error { - _, err := o.db.ExecContext(ctx, ` + _, err := o.ds.ExecContext(ctx, ` DELETE FROM log_broadcasts WHERE evm_chain_id = $1 AND consumed = false diff --git a/core/chains/evm/log/orm_test.go b/core/chains/evm/log/orm_test.go index ba9509d4518..1a6d927cd50 100644 --- a/core/chains/evm/log/orm_test.go +++ b/core/chains/evm/log/orm_test.go @@ -21,7 +21,6 @@ func TestORM_broadcasts(t *testing.T) { db := pgtest.NewSqlxDB(t) cfg := configtest.NewGeneralConfig(t, nil) ethKeyStore := cltest.NewKeyStore(t, db, cfg.Database()).Eth() - ctx := testutils.Context(t) orm := log.NewORM(db, cltest.FixtureChainID) @@ -44,12 +43,12 @@ func TestORM_broadcasts(t *testing.T) { require.Zero(t, rowsAffected) t.Run("WasBroadcastConsumed_DNE", func(t *testing.T) { - _, err := orm.WasBroadcastConsumed(ctx, rawLog.BlockHash, rawLog.Index, listener.JobID()) + _, err := orm.WasBroadcastConsumed(testutils.Context(t), rawLog.BlockHash, rawLog.Index, listener.JobID()) require.NoError(t, err) }) require.True(t, t.Run("CreateBroadcast", func(t *testing.T) { - err := orm.CreateBroadcast(ctx, rawLog.BlockHash, rawLog.BlockNumber, rawLog.Index, listener.JobID()) + err := orm.CreateBroadcast(testutils.Context(t), rawLog.BlockHash, rawLog.BlockNumber, rawLog.Index, listener.JobID()) require.NoError(t, err) var consumed null.Bool @@ -59,13 +58,13 @@ func TestORM_broadcasts(t *testing.T) { })) t.Run("WasBroadcastConsumed_false", func(t *testing.T) { - was, err := orm.WasBroadcastConsumed(ctx, rawLog.BlockHash, rawLog.Index, listener.JobID()) + was, err := orm.WasBroadcastConsumed(testutils.Context(t), rawLog.BlockHash, rawLog.Index, listener.JobID()) require.NoError(t, err) require.False(t, was) }) require.True(t, t.Run("MarkBroadcastConsumed", func(t *testing.T) { - err := orm.MarkBroadcastConsumed(ctx, rawLog.BlockHash, rawLog.BlockNumber, rawLog.Index, listener.JobID()) + err := orm.MarkBroadcastConsumed(testutils.Context(t), rawLog.BlockHash, rawLog.BlockNumber, rawLog.Index, listener.JobID()) require.NoError(t, err) var consumed null.Bool @@ -74,66 +73,17 @@ func TestORM_broadcasts(t *testing.T) { require.Equal(t, null.BoolFrom(true), consumed) })) - t.Run("MarkBroadcastsConsumed Success", func(t *testing.T) { - var ( - err error - blockHashes []common.Hash - blockNumbers []uint64 - logIndexes []uint - jobIDs []int32 - ) - for i := 0; i < 3; i++ { - l := cltest.RandomLog(t) - err = orm.CreateBroadcast(ctx, l.BlockHash, l.BlockNumber, l.Index, listener.JobID()) - require.NoError(t, err) - blockHashes = append(blockHashes, l.BlockHash) - blockNumbers = append(blockNumbers, l.BlockNumber) - logIndexes = append(logIndexes, l.Index) - jobIDs = append(jobIDs, listener.JobID()) - - } - err = orm.MarkBroadcastsConsumed(ctx, blockHashes, blockNumbers, logIndexes, jobIDs) - require.NoError(t, err) - - for i := range blockHashes { - was, err := orm.WasBroadcastConsumed(ctx, blockHashes[i], logIndexes[i], jobIDs[i]) - require.NoError(t, err) - require.True(t, was) - } - }) - - t.Run("MarkBroadcastsConsumed Failure", func(t *testing.T) { - var ( - err error - blockHashes []common.Hash - blockNumbers []uint64 - logIndexes []uint - jobIDs []int32 - ) - for i := 0; i < 5; i++ { - l := cltest.RandomLog(t) - err = orm.CreateBroadcast(ctx, l.BlockHash, l.BlockNumber, l.Index, listener.JobID()) - require.NoError(t, err) - blockHashes = append(blockHashes, l.BlockHash) - blockNumbers = append(blockNumbers, l.BlockNumber) - logIndexes = append(logIndexes, l.Index) - jobIDs = append(jobIDs, listener.JobID()) - } - err = orm.MarkBroadcastsConsumed(ctx, blockHashes[:len(blockHashes)-2], blockNumbers, logIndexes, jobIDs) - require.Error(t, err) - }) - t.Run("WasBroadcastConsumed_true", func(t *testing.T) { - was, err := orm.WasBroadcastConsumed(ctx, rawLog.BlockHash, rawLog.Index, listener.JobID()) + was, err := orm.WasBroadcastConsumed(testutils.Context(t), rawLog.BlockHash, rawLog.Index, listener.JobID()) require.NoError(t, err) require.True(t, was) }) } func TestORM_pending(t *testing.T) { + ctx := testutils.Context(t) db := pgtest.NewSqlxDB(t) orm := log.NewORM(db, cltest.FixtureChainID) - ctx := testutils.Context(t) num, err := orm.GetPendingMinBlock(ctx) require.NoError(t, err) @@ -156,9 +106,9 @@ func TestORM_pending(t *testing.T) { } func TestORM_MarkUnconsumed(t *testing.T) { + ctx := testutils.Context(t) db := pgtest.NewSqlxDB(t) cfg := configtest.NewGeneralConfig(t, nil) - ctx := testutils.Context(t) ethKeyStore := cltest.NewKeyStore(t, db, cfg.Database()).Eth() orm := log.NewORM(db, cltest.FixtureChainID) @@ -256,8 +206,8 @@ func TestORM_Reinitialize(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { db := pgtest.NewSqlxDB(t) - orm := log.NewORM(db, cltest.FixtureChainID) ctx := testutils.Context(t) + orm := log.NewORM(db, cltest.FixtureChainID) jobID := cltest.MustInsertV2JobSpec(t, db, common.BigToAddress(big.NewInt(rand.Int63()))).ID diff --git a/core/chains/evm/log/registrations.go b/core/chains/evm/log/registrations.go index b56d3f4aaaa..c82fee43b6e 100644 --- a/core/chains/evm/log/registrations.go +++ b/core/chains/evm/log/registrations.go @@ -62,7 +62,7 @@ type ( // The Listener responds to log events through HandleLog. Listener interface { - HandleLog(b Broadcast) + HandleLog(ctx context.Context, b Broadcast) JobID() int32 } @@ -240,6 +240,9 @@ func (r *registrations) sendLogs(ctx context.Context, logsToSend []logsOnBlock, for _, log := range logsPerBlock.Logs { handlers.sendLog(ctx, log, latestHead, broadcastsExisting, bc, r.logger) + if ctx.Err() != nil { + return + } } } } @@ -442,7 +445,7 @@ func (r *handler) sendLog(ctx context.Context, log types.Log, latestHead evmtype wg.Add(1) go func() { defer wg.Done() - handleLog(&broadcast{ + handleLog(ctx, &broadcast{ latestBlockNumber, latestHead.Hash, latestHead.ReceiptsRoot, diff --git a/core/chains/evm/log/registrations_test.go b/core/chains/evm/log/registrations_test.go index 2be01dca2bf..8c0beaa9379 100644 --- a/core/chains/evm/log/registrations_test.go +++ b/core/chains/evm/log/registrations_test.go @@ -1,6 +1,7 @@ package log import ( + "context" "testing" "github.com/ethereum/go-ethereum/common" @@ -18,8 +19,8 @@ type testListener struct { jobID int32 } -func (tl testListener) JobID() int32 { return tl.jobID } -func (tl testListener) HandleLog(Broadcast) { panic("not implemented") } +func (tl testListener) JobID() int32 { return tl.jobID } +func (tl testListener) HandleLog(context.Context, Broadcast) { panic("not implemented") } func newTestListener(t *testing.T, jobID int32) testListener { return testListener{jobID} diff --git a/core/chains/evm/txmgr/broadcaster_test.go b/core/chains/evm/txmgr/broadcaster_test.go index 1e8f1c73b34..3500002e8da 100644 --- a/core/chains/evm/txmgr/broadcaster_test.go +++ b/core/chains/evm/txmgr/broadcaster_test.go @@ -1113,7 +1113,7 @@ func TestEthBroadcaster_ProcessUnstartedEthTxs_Errors(t *testing.T) { t.Run("with erroring callback bails out", func(t *testing.T) { require.NoError(t, txStore.InsertTx(ctx, &etx)) - fn := func(id uuid.UUID, result interface{}, err error) error { + fn := func(ctx context.Context, id uuid.UUID, result interface{}, err error) error { return errors.New("something exploded in the callback") } @@ -1130,7 +1130,7 @@ func TestEthBroadcaster_ProcessUnstartedEthTxs_Errors(t *testing.T) { }) t.Run("calls resume with error", func(t *testing.T) { - fn := func(id uuid.UUID, result interface{}, err error) error { + fn := func(ctx context.Context, id uuid.UUID, result interface{}, err error) error { require.Equal(t, id, tr.ID) require.Nil(t, result) require.Error(t, err) diff --git a/core/chains/evm/txmgr/confirmer_test.go b/core/chains/evm/txmgr/confirmer_test.go index 80868d448e0..357dafcbdc4 100644 --- a/core/chains/evm/txmgr/confirmer_test.go +++ b/core/chains/evm/txmgr/confirmer_test.go @@ -1,6 +1,7 @@ package txmgr_test import ( + "context" "encoding/json" "errors" "fmt" @@ -2966,7 +2967,7 @@ func TestEthConfirmer_ResumePendingRuns(t *testing.T) { pgtest.MustExec(t, db, `SET CONSTRAINTS pipeline_runs_pipeline_spec_id_fkey DEFERRED`) t.Run("doesn't process task runs that are not suspended (possibly already previously resumed)", func(t *testing.T) { - ec := newEthConfirmer(t, txStore, ethClient, evmcfg, ethKeyStore, func(uuid.UUID, interface{}, error) error { + ec := newEthConfirmer(t, txStore, ethClient, evmcfg, ethKeyStore, func(context.Context, uuid.UUID, interface{}, error) error { t.Fatal("No value expected") return nil }) @@ -2985,7 +2986,7 @@ func TestEthConfirmer_ResumePendingRuns(t *testing.T) { }) t.Run("doesn't process task runs where the receipt is younger than minConfirmations", func(t *testing.T) { - ec := newEthConfirmer(t, txStore, ethClient, evmcfg, ethKeyStore, func(uuid.UUID, interface{}, error) error { + ec := newEthConfirmer(t, txStore, ethClient, evmcfg, ethKeyStore, func(context.Context, uuid.UUID, interface{}, error) error { t.Fatal("No value expected") return nil }) @@ -3006,7 +3007,7 @@ func TestEthConfirmer_ResumePendingRuns(t *testing.T) { ch := make(chan interface{}) nonce := evmtypes.Nonce(3) var err error - ec := newEthConfirmer(t, txStore, ethClient, evmcfg, ethKeyStore, func(id uuid.UUID, value interface{}, thisErr error) error { + ec := newEthConfirmer(t, txStore, ethClient, evmcfg, ethKeyStore, func(ctx context.Context, id uuid.UUID, value interface{}, thisErr error) error { err = thisErr ch <- value return nil @@ -3059,7 +3060,7 @@ func TestEthConfirmer_ResumePendingRuns(t *testing.T) { } ch := make(chan data) nonce := evmtypes.Nonce(4) - ec := newEthConfirmer(t, txStore, ethClient, evmcfg, ethKeyStore, func(id uuid.UUID, value interface{}, err error) error { + ec := newEthConfirmer(t, txStore, ethClient, evmcfg, ethKeyStore, func(ctx context.Context, id uuid.UUID, value interface{}, err error) error { ch <- data{value, err} return nil }) @@ -3106,7 +3107,7 @@ func TestEthConfirmer_ResumePendingRuns(t *testing.T) { t.Run("does not mark callback complete if callback fails", func(t *testing.T) { nonce := evmtypes.Nonce(5) - ec := newEthConfirmer(t, txStore, ethClient, evmcfg, ethKeyStore, func(uuid.UUID, interface{}, error) error { + ec := newEthConfirmer(t, txStore, ethClient, evmcfg, ethKeyStore, func(context.Context, uuid.UUID, interface{}, error) error { return errors.New("error") }) diff --git a/core/internal/cltest/cltest.go b/core/internal/cltest/cltest.go index 3a92269cc03..ba182d60515 100644 --- a/core/internal/cltest/cltest.go +++ b/core/internal/cltest/cltest.go @@ -182,7 +182,7 @@ type JobPipelineConfig interface { func NewJobPipelineV2(t testing.TB, cfg pipeline.BridgeConfig, jpcfg JobPipelineConfig, dbCfg pg.QConfig, legacyChains legacyevm.LegacyChainContainer, db *sqlx.DB, keyStore keystore.Master, restrictedHTTPClient, unrestrictedHTTPClient *http.Client) JobPipelineV2TestHelper { lggr := logger.TestLogger(t) - prm := pipeline.NewORM(db, lggr, dbCfg, jpcfg.MaxSuccessfulRuns()) + prm := pipeline.NewORM(db, lggr, jpcfg.MaxSuccessfulRuns()) btORM := bridges.NewORM(db) jrm := job.NewORM(db, prm, btORM, keyStore, lggr, dbCfg) pr := pipeline.NewRunner(prm, btORM, jpcfg, cfg, legacyChains, keyStore.Eth(), keyStore.VRF(), lggr, restrictedHTTPClient, unrestrictedHTTPClient) diff --git a/core/internal/cltest/factories.go b/core/internal/cltest/factories.go index d7e1036bcac..43cf902ca8a 100644 --- a/core/internal/cltest/factories.go +++ b/core/internal/cltest/factories.go @@ -402,7 +402,7 @@ func MustInsertKeeperJob(t *testing.T, db *sqlx.DB, korm *keeper.ORM, from evmty cfg := configtest.NewTestGeneralConfig(t) tlg := logger.TestLogger(t) - prm := pipeline.NewORM(db, tlg, cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) + prm := pipeline.NewORM(db, tlg, cfg.JobPipeline().MaxSuccessfulRuns()) btORM := bridges.NewORM(db) jrm := job.NewORM(db, prm, btORM, nil, tlg, cfg.Database()) err = jrm.InsertJob(&jb) diff --git a/core/internal/cltest/job_factories.go b/core/internal/cltest/job_factories.go index 5d8f75e36c3..2b527fbc29c 100644 --- a/core/internal/cltest/job_factories.go +++ b/core/internal/cltest/job_factories.go @@ -11,6 +11,7 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/bridges" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" + "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/configtest" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/job" @@ -43,12 +44,13 @@ func MinimalOCRNonBootstrapSpec(contractAddress, transmitterAddress types.EIP55A } func MustInsertWebhookSpec(t *testing.T, db *sqlx.DB) (job.Job, job.WebhookSpec) { + ctx := testutils.Context(t) jobORM, pipelineORM := getORMs(t, db) webhookSpec := job.WebhookSpec{} require.NoError(t, jobORM.InsertWebhookSpec(&webhookSpec)) pSpec := pipeline.Pipeline{} - pipelineSpecID, err := pipelineORM.CreateSpec(pSpec, 0) + pipelineSpecID, err := pipelineORM.CreateSpec(ctx, nil, pSpec, 0) require.NoError(t, err) createdJob := job.Job{WebhookSpecID: &webhookSpec.ID, WebhookSpec: &webhookSpec, SchemaVersion: 1, Type: "webhook", @@ -62,7 +64,7 @@ func getORMs(t *testing.T, db *sqlx.DB) (jobORM job.ORM, pipelineORM pipeline.OR config := configtest.NewTestGeneralConfig(t) keyStore := NewKeyStore(t, db, config.Database()) lggr := logger.TestLogger(t) - pipelineORM = pipeline.NewORM(db, lggr, config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM = pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()) bridgeORM := bridges.NewORM(db) jobORM = job.NewORM(db, pipelineORM, bridgeORM, keyStore, lggr, config.Database()) t.Cleanup(func() { jobORM.Close() }) diff --git a/core/internal/features/features_test.go b/core/internal/features/features_test.go index 2c40c848263..4afad453110 100644 --- a/core/internal/features/features_test.go +++ b/core/internal/features/features_test.go @@ -236,7 +236,7 @@ observationSource = """ _ = cltest.CreateJobRunViaExternalInitiatorV2(t, app, jobUUID, *eia, cltest.MustJSONMarshal(t, eiRequest)) - pipelineORM := pipeline.NewORM(app.GetSqlxDB(), logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(app.GetSqlxDB(), logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) bridgeORM := bridges.NewORM(app.GetSqlxDB()) jobORM := job.NewORM(app.GetSqlxDB(), pipelineORM, bridgeORM, app.KeyStore, logger.TestLogger(t), cfg.Database()) diff --git a/core/services/blockhashstore/delegate.go b/core/services/blockhashstore/delegate.go index 9a11c057c32..6bcfc26ddb6 100644 --- a/core/services/blockhashstore/delegate.go +++ b/core/services/blockhashstore/delegate.go @@ -19,7 +19,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/job" "github.com/smartcontractkit/chainlink/v2/core/services/keystore" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/utils" ) @@ -194,7 +193,7 @@ func (d *Delegate) BeforeJobCreated(spec job.Job) {} func (d *Delegate) BeforeJobDeleted(spec job.Job) {} // OnDeleteJob satisfies the job.Delegate interface. -func (d *Delegate) OnDeleteJob(ctx context.Context, spec job.Job, q pg.Queryer) error { return nil } +func (d *Delegate) OnDeleteJob(context.Context, job.Job) error { return nil } // service is a job.Service that runs the BHS feeder every pollPeriod. type service struct { diff --git a/core/services/blockheaderfeeder/delegate.go b/core/services/blockheaderfeeder/delegate.go index 19edb43bc23..07cab534af7 100644 --- a/core/services/blockheaderfeeder/delegate.go +++ b/core/services/blockheaderfeeder/delegate.go @@ -19,7 +19,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/blockhashstore" "github.com/smartcontractkit/chainlink/v2/core/services/job" "github.com/smartcontractkit/chainlink/v2/core/services/keystore" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/utils" ) @@ -208,7 +207,7 @@ func (d *Delegate) BeforeJobCreated(spec job.Job) {} func (d *Delegate) BeforeJobDeleted(spec job.Job) {} // OnDeleteJob satisfies the job.Delegate interface. -func (d *Delegate) OnDeleteJob(ctx context.Context, spec job.Job, q pg.Queryer) error { return nil } +func (d *Delegate) OnDeleteJob(context.Context, job.Job) error { return nil } // service is a job.Service that runs the BHS feeder every pollPeriod. type service struct { diff --git a/core/services/chainlink/application.go b/core/services/chainlink/application.go index 832bea523b5..8542074c27c 100644 --- a/core/services/chainlink/application.go +++ b/core/services/chainlink/application.go @@ -308,7 +308,7 @@ func NewApplication(opts ApplicationOpts) (Application, error) { } var ( - pipelineORM = pipeline.NewORM(sqlxDB, globalLogger, cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) + pipelineORM = pipeline.NewORM(sqlxDB, globalLogger, cfg.JobPipeline().MaxSuccessfulRuns()) bridgeORM = bridges.NewORM(sqlxDB) mercuryORM = mercury.NewORM(opts.DB) pipelineRunner = pipeline.NewRunner(pipelineORM, bridgeORM, cfg.JobPipeline(), cfg.WebServer(), legacyEVMChains, keyStore.Eth(), keyStore.VRF(), globalLogger, restrictedHTTPClient, unrestrictedHTTPClient) @@ -346,7 +346,6 @@ func NewApplication(opts ApplicationOpts) (Application, error) { pipelineORM, legacyEVMChains, globalLogger, - cfg.Database(), mailMon), job.Webhook: webhook.NewDelegate( pipelineRunner, @@ -829,7 +828,7 @@ func (app *ChainlinkApplication) ResumeJobV2( taskID uuid.UUID, result pipeline.Result, ) error { - return app.pipelineRunner.ResumeRun(taskID, result.Value, result.Error) + return app.pipelineRunner.ResumeRun(ctx, taskID, result.Value, result.Error) } func (app *ChainlinkApplication) GetFeedsService() feeds.Service { diff --git a/core/services/cron/cron_test.go b/core/services/cron/cron_test.go index 3ace0f3ceae..c3ecc0957c7 100644 --- a/core/services/cron/cron_test.go +++ b/core/services/cron/cron_test.go @@ -27,7 +27,7 @@ func TestCronV2Pipeline(t *testing.T) { keyStore := cltest.NewKeyStore(t, db, cfg.Database()) lggr := logger.TestLogger(t) - orm := pipeline.NewORM(db, lggr, cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) + orm := pipeline.NewORM(db, lggr, cfg.JobPipeline().MaxSuccessfulRuns()) btORM := bridges.NewORM(db) jobORM := job.NewORM(db, orm, btORM, keyStore, lggr, cfg.Database()) diff --git a/core/services/cron/delegate.go b/core/services/cron/delegate.go index 05b5b36c00f..d8a1390103e 100644 --- a/core/services/cron/delegate.go +++ b/core/services/cron/delegate.go @@ -7,7 +7,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/job" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" ) @@ -29,10 +28,10 @@ func (d *Delegate) JobType() job.Type { return job.Cron } -func (d *Delegate) BeforeJobCreated(spec job.Job) {} -func (d *Delegate) AfterJobCreated(spec job.Job) {} -func (d *Delegate) BeforeJobDeleted(spec job.Job) {} -func (d *Delegate) OnDeleteJob(ctx context.Context, spec job.Job, q pg.Queryer) error { return nil } +func (d *Delegate) BeforeJobCreated(spec job.Job) {} +func (d *Delegate) AfterJobCreated(spec job.Job) {} +func (d *Delegate) BeforeJobDeleted(spec job.Job) {} +func (d *Delegate) OnDeleteJob(context.Context, job.Job) error { return nil } // ServicesForSpec returns the scheduler to be used for running cron jobs func (d *Delegate) ServicesForSpec(ctx context.Context, spec job.Job) (services []job.ServiceCtx, err error) { diff --git a/core/services/directrequest/delegate.go b/core/services/directrequest/delegate.go index d6afc215fb9..33a0a7e73da 100644 --- a/core/services/directrequest/delegate.go +++ b/core/services/directrequest/delegate.go @@ -11,6 +11,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/assets" "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/utils/mailbox" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/log" @@ -19,7 +20,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/generated/operator_wrapper" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/job" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" "github.com/smartcontractkit/chainlink/v2/core/store/models" ) @@ -63,10 +63,10 @@ func (d *Delegate) JobType() job.Type { return job.DirectRequest } -func (d *Delegate) BeforeJobCreated(spec job.Job) {} -func (d *Delegate) AfterJobCreated(spec job.Job) {} -func (d *Delegate) BeforeJobDeleted(spec job.Job) {} -func (d *Delegate) OnDeleteJob(ctx context.Context, spec job.Job, q pg.Queryer) error { return nil } +func (d *Delegate) BeforeJobCreated(spec job.Job) {} +func (d *Delegate) AfterJobCreated(spec job.Job) {} +func (d *Delegate) BeforeJobDeleted(spec job.Job) {} +func (d *Delegate) OnDeleteJob(context.Context, job.Job) error { return nil } // ServicesForSpec returns the log listener service for a direct request job func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) ([]job.ServiceCtx, error) { @@ -191,7 +191,7 @@ func (l *listener) Close() error { }) } -func (l *listener) HandleLog(lb log.Broadcast) { +func (l *listener) HandleLog(ctx context.Context, lb log.Broadcast) { log := lb.DecodedLog() if log == nil || reflect.ValueOf(log).IsNil() { l.logger.Error("HandleLog: ignoring nil value") @@ -374,7 +374,7 @@ func (l *listener) handleOracleRequest(ctx context.Context, request *operator_wr }, }) run := pipeline.NewRun(*l.job.PipelineSpec, vars) - _, err := l.pipelineRunner.Run(ctx, run, l.logger, true, func(tx pg.Queryer) error { + _, err := l.pipelineRunner.Run(ctx, run, l.logger, true, func(tx sqlutil.DataSource) error { l.markLogConsumed(ctx, lb) return nil }) @@ -407,7 +407,7 @@ func (l *listener) handleCancelOracleRequest(ctx context.Context, request *opera } func (l *listener) markLogConsumed(ctx context.Context, lb log.Broadcast) { - if err := l.logBroadcaster.MarkConsumed(ctx, lb); err != nil { + if err := l.logBroadcaster.MarkConsumed(ctx, nil, lb); err != nil { l.logger.Errorw("Unable to mark log consumed", "err", err, "log", lb.String()) } } diff --git a/core/services/directrequest/delegate_test.go b/core/services/directrequest/delegate_test.go index a7f2ba01315..0235a0c4eec 100644 --- a/core/services/directrequest/delegate_test.go +++ b/core/services/directrequest/delegate_test.go @@ -15,6 +15,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/assets" "github.com/smartcontractkit/chainlink-common/pkg/services/servicetest" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/utils/mailbox/mailboxtest" "github.com/smartcontractkit/chainlink/v2/core/bridges" @@ -31,7 +32,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/chainlink" "github.com/smartcontractkit/chainlink/v2/core/services/directrequest" "github.com/smartcontractkit/chainlink/v2/core/services/job" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" pipeline_mocks "github.com/smartcontractkit/chainlink/v2/core/services/pipeline/mocks" evmrelay "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm" @@ -88,7 +88,7 @@ func NewDirectRequestUniverseWithConfig(t *testing.T, cfg chainlink.GeneralConfi keyStore := cltest.NewKeyStore(t, db, cfg.Database()) relayExtenders := evmtest.NewChainRelayExtenders(t, evmtest.TestChainOpts{DB: db, GeneralConfig: cfg, Client: ethClient, LogBroadcaster: broadcaster, MailMon: mailMon, KeyStore: keyStore.Eth()}) lggr := logger.TestLogger(t) - orm := pipeline.NewORM(db, lggr, cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) + orm := pipeline.NewORM(db, lggr, cfg.JobPipeline().MaxSuccessfulRuns()) btORM := bridges.NewORM(db) jobORM := job.NewORM(db, orm, btORM, keyStore, lggr, cfg.Database()) legacyChains := evmrelay.NewLegacyChainsFromRelayerExtenders(relayExtenders) @@ -159,28 +159,29 @@ func TestDelegate_ServicesListenerHandleLog(t *testing.T) { }) log.On("DecodedLog").Return(&logOracleRequest) log.On("String").Return("") - uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) runBeganAwaiter := cltest.NewAwaiter() uni.runner.On("Run", mock.Anything, mock.AnythingOfType("*pipeline.Run"), mock.Anything, mock.Anything, mock.Anything). Return(false, nil). Run(func(args mock.Arguments) { runBeganAwaiter.ItHappened() - fn := args.Get(4).(func(pg.Queryer) error) + fn := args.Get(4).(func(source sqlutil.DataSource) error) require.NoError(t, fn(nil)) }).Once() - err := uni.service.Start(testutils.Context(t)) + ctx := testutils.Context(t) + err := uni.service.Start(ctx) require.NoError(t, err) require.NotNil(t, uni.listener, "listener was nil; expected broadcaster.Register to have been called") // check if the job exists under the correct ID - drJob, jErr := uni.jobORM.FindJob(testutils.Context(t), uni.listener.JobID()) + drJob, jErr := uni.jobORM.FindJob(ctx, uni.listener.JobID()) require.NoError(t, jErr) require.Equal(t, drJob.ID, uni.listener.JobID()) require.NotNil(t, drJob.DirectRequestSpec) - uni.listener.HandleLog(log) + uni.listener.HandleLog(ctx, log) runBeganAwaiter.AwaitOrFail(t, 5*time.Second) @@ -207,12 +208,13 @@ func TestDelegate_ServicesListenerHandleLog(t *testing.T) { log.On("DecodedLog").Return(&logOracleRequest).Maybe() log.On("String").Return("") log.On("EVMChainID").Return(*big.NewInt(0)) - uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil).Maybe() + uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() - err := uni.service.Start(testutils.Context(t)) + ctx := testutils.Context(t) + err := uni.service.Start(ctx) require.NoError(t, err) - uni.listener.HandleLog(log) + uni.listener.HandleLog(ctx, log) uni.logBroadcaster.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) @@ -224,7 +226,7 @@ func TestDelegate_ServicesListenerHandleLog(t *testing.T) { uni.runner.On("Run", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). Run(func(args mock.Arguments) { runBeganAwaiter.ItHappened() - fn := args.Get(4).(func(pg.Queryer) error) + fn := args.Get(4).(func(sqlutil.DataSource) error) require.NoError(t, fn(nil)) }).Once().Return(false, nil) @@ -241,7 +243,7 @@ func TestDelegate_ServicesListenerHandleLog(t *testing.T) { log := log_mocks.NewBroadcast(t) lbAwaiter := cltest.NewAwaiter() uni.logBroadcaster.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) - uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { lbAwaiter.ItHappened() }).Return(nil) + uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { lbAwaiter.ItHappened() }).Return(nil) logCancelOracleRequest := operator_wrapper.OperatorCancelOracleRequest{RequestId: uni.spec.ExternalIDEncodeStringToTopic()} logAwaiter := cltest.NewAwaiter() @@ -251,10 +253,11 @@ func TestDelegate_ServicesListenerHandleLog(t *testing.T) { }) log.On("String").Return("") - err := uni.service.Start(testutils.Context(t)) + ctx := testutils.Context(t) + err := uni.service.Start(ctx) require.NoError(t, err) - uni.listener.HandleLog(log) + uni.listener.HandleLog(ctx, log) logAwaiter.AwaitOrFail(t) lbAwaiter.AwaitOrFail(t) @@ -279,12 +282,13 @@ func TestDelegate_ServicesListenerHandleLog(t *testing.T) { log.On("String").Return("") log.On("DecodedLog").Return(&logCancelOracleRequest) lbAwaiter := cltest.NewAwaiter() - uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { lbAwaiter.ItHappened() }).Return(nil) + uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { lbAwaiter.ItHappened() }).Return(nil) - err := uni.service.Start(testutils.Context(t)) + ctx := testutils.Context(t) + err := uni.service.Start(ctx) require.NoError(t, err) - uni.listener.HandleLog(log) + uni.listener.HandleLog(ctx, log) lbAwaiter.AwaitOrFail(t) @@ -314,7 +318,7 @@ func TestDelegate_ServicesListenerHandleLog(t *testing.T) { }) runLog.On("DecodedLog").Return(&logOracleRequest) runLog.On("String").Return("") - uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) cancelLog := log_mocks.NewBroadcast(t) @@ -328,9 +332,10 @@ func TestDelegate_ServicesListenerHandleLog(t *testing.T) { }) cancelLog.On("DecodedLog").Return(&logCancelOracleRequest) cancelLog.On("String").Return("") - uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) - err := uni.service.Start(testutils.Context(t)) + ctx := testutils.Context(t) + err := uni.service.Start(ctx) require.NoError(t, err) timeout := 5 * time.Second @@ -346,11 +351,11 @@ func TestDelegate_ServicesListenerHandleLog(t *testing.T) { runCancelledAwaiter.ItHappened() } }).Once().Return(false, nil) - uni.listener.HandleLog(runLog) + uni.listener.HandleLog(ctx, runLog) runBeganAwaiter.AwaitOrFail(t, timeout) - uni.listener.HandleLog(cancelLog) + uni.listener.HandleLog(ctx, cancelLog) runCancelledAwaiter.AwaitOrFail(t, timeout) @@ -384,25 +389,26 @@ func TestDelegate_ServicesListenerHandleLog(t *testing.T) { }) log.On("DecodedLog").Return(&logOracleRequest) log.On("String").Return("") - uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) runBeganAwaiter := cltest.NewAwaiter() uni.runner.On("Run", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { runBeganAwaiter.ItHappened() - fn := args.Get(4).(func(pg.Queryer) error) + fn := args.Get(4).(func(sqlutil.DataSource) error) require.NoError(t, fn(nil)) }).Once().Return(false, nil) - err := uni.service.Start(testutils.Context(t)) + ctx := testutils.Context(t) + err := uni.service.Start(ctx) require.NoError(t, err) // check if the job exists under the correct ID - drJob, jErr := uni.jobORM.FindJob(testutils.Context(t), uni.listener.JobID()) + drJob, jErr := uni.jobORM.FindJob(ctx, uni.listener.JobID()) require.NoError(t, jErr) require.Equal(t, drJob.ID, uni.listener.JobID()) require.NotNil(t, drJob.DirectRequestSpec) - uni.listener.HandleLog(log) + uni.listener.HandleLog(ctx, log) runBeganAwaiter.AwaitOrFail(t, 5*time.Second) @@ -433,14 +439,15 @@ func TestDelegate_ServicesListenerHandleLog(t *testing.T) { log.On("DecodedLog").Return(&logOracleRequest) log.On("String").Return("") markConsumedLogAwaiter := cltest.NewAwaiter() - uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { markConsumedLogAwaiter.ItHappened() }).Return(nil) - err := uni.service.Start(testutils.Context(t)) + ctx := testutils.Context(t) + err := uni.service.Start(ctx) require.NoError(t, err) - uni.listener.HandleLog(log) + uni.listener.HandleLog(ctx, log) markConsumedLogAwaiter.AwaitOrFail(t, 5*time.Second) @@ -479,27 +486,28 @@ func TestDelegate_ServicesListenerHandleLog(t *testing.T) { log.On("DecodedLog").Return(&logOracleRequest) log.On("String").Return("") markConsumedLogAwaiter := cltest.NewAwaiter() - uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { markConsumedLogAwaiter.ItHappened() }).Return(nil) runBeganAwaiter := cltest.NewAwaiter() uni.runner.On("Run", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { runBeganAwaiter.ItHappened() - fn := args.Get(4).(func(pg.Queryer) error) + fn := args.Get(4).(func(sqlutil.DataSource) error) require.NoError(t, fn(nil)) }).Once().Return(false, nil) - err := uni.service.Start(testutils.Context(t)) + ctx := testutils.Context(t) + err := uni.service.Start(ctx) require.NoError(t, err) // check if the job exists under the correct ID - drJob, jErr := uni.jobORM.FindJob(testutils.Context(t), uni.listener.JobID()) + drJob, jErr := uni.jobORM.FindJob(ctx, uni.listener.JobID()) require.NoError(t, jErr) require.Equal(t, drJob.ID, uni.listener.JobID()) require.NotNil(t, drJob.DirectRequestSpec) - uni.listener.HandleLog(log) + uni.listener.HandleLog(ctx, log) runBeganAwaiter.AwaitOrFail(t, 5*time.Second) @@ -534,14 +542,15 @@ func TestDelegate_ServicesListenerHandleLog(t *testing.T) { log.On("DecodedLog").Return(&logOracleRequest) log.On("String").Return("") markConsumedLogAwaiter := cltest.NewAwaiter() - uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + uni.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { markConsumedLogAwaiter.ItHappened() }).Return(nil) - err := uni.service.Start(testutils.Context(t)) + ctx := testutils.Context(t) + err := uni.service.Start(ctx) require.NoError(t, err) - uni.listener.HandleLog(log) + uni.listener.HandleLog(ctx, log) markConsumedLogAwaiter.AwaitOrFail(t, 5*time.Second) diff --git a/core/services/feeds/orm_test.go b/core/services/feeds/orm_test.go index 23f40b9d55c..51a85a33a46 100644 --- a/core/services/feeds/orm_test.go +++ b/core/services/feeds/orm_test.go @@ -1652,7 +1652,7 @@ func createJob(t *testing.T, db *sqlx.DB, externalJobID uuid.UUID) *job.Job { config = configtest.NewGeneralConfig(t, nil) keyStore = cltest.NewKeyStore(t, db, config.Database()) lggr = logger.TestLogger(t) - pipelineORM = pipeline.NewORM(db, lggr, config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM = pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()) bridgeORM = bridges.NewORM(db) relayExtenders = evmtest.NewChainRelayExtenders(t, evmtest.TestChainOpts{DB: db, GeneralConfig: config, KeyStore: keyStore.Eth()}) ) diff --git a/core/services/feeds/service.go b/core/services/feeds/service.go index aa6ccdb39d7..27d324d2342 100644 --- a/core/services/feeds/service.go +++ b/core/services/feeds/service.go @@ -1191,7 +1191,7 @@ func (s *service) newChainConfigMsg(cfg ChainConfig) (*pb.ChainConfig, error) { }, nil } -// newFMConfigMsg generates a FMConfig protobuf message. Flux Monitor does not +// newFluxMonitorConfigMsg generates a FMConfig protobuf message. Flux Monitor does not // have any configuration but this is here for consistency. func (*service) newFluxMonitorConfigMsg(cfg FluxMonitorConfig) *pb.FluxMonitorConfig { return &pb.FluxMonitorConfig{Enabled: cfg.Enabled} diff --git a/core/services/fluxmonitorv2/delegate.go b/core/services/fluxmonitorv2/delegate.go index 1e2eba8d000..ddb255800b1 100644 --- a/core/services/fluxmonitorv2/delegate.go +++ b/core/services/fluxmonitorv2/delegate.go @@ -13,7 +13,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/job" "github.com/smartcontractkit/chainlink/v2/core/services/keystore" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" ) @@ -56,10 +55,10 @@ func (d *Delegate) JobType() job.Type { return job.FluxMonitor } -func (d *Delegate) BeforeJobCreated(spec job.Job) {} -func (d *Delegate) AfterJobCreated(spec job.Job) {} -func (d *Delegate) BeforeJobDeleted(spec job.Job) {} -func (d *Delegate) OnDeleteJob(ctx context.Context, spec job.Job, q pg.Queryer) error { return nil } +func (d *Delegate) BeforeJobCreated(spec job.Job) {} +func (d *Delegate) AfterJobCreated(spec job.Job) {} +func (d *Delegate) BeforeJobDeleted(spec job.Job) {} +func (d *Delegate) OnDeleteJob(context.Context, job.Job) error { return nil } // ServicesForSpec returns the flux monitor service for the job spec func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) (services []job.ServiceCtx, err error) { @@ -80,7 +79,7 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) (services [] fm, err := NewFromJobSpec( jb, d.db, - NewORM(d.db, d.lggr, chain.Config().Database(), chain.TxManager(), strategy, checker), + NewORM(d.db, d.lggr, chain.TxManager(), strategy, checker), d.jobORM, d.pipelineORM, NewKeyStore(d.ethKeyStore), @@ -89,10 +88,7 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) (services [] d.pipelineRunner, chain.Config().EVM(), chain.Config().EVM().GasEstimator(), - chain.Config().EVM().Transactions(), - chain.Config().FluxMonitor(), chain.Config().JobPipeline(), - chain.Config().Database(), d.lggr, ) if err != nil { diff --git a/core/services/fluxmonitorv2/flux_monitor.go b/core/services/fluxmonitorv2/flux_monitor.go index 73034faa3ce..5eebb319030 100644 --- a/core/services/fluxmonitorv2/flux_monitor.go +++ b/core/services/fluxmonitorv2/flux_monitor.go @@ -16,6 +16,7 @@ import ( "github.com/jmoiron/sqlx" "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/bridges" evmclient "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" @@ -27,7 +28,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/recovery" "github.com/smartcontractkit/chainlink/v2/core/services/fluxmonitorv2/promfm" "github.com/smartcontractkit/chainlink/v2/core/services/job" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" "github.com/smartcontractkit/chainlink/v2/core/utils" ) @@ -64,7 +64,7 @@ type FluxMonitor struct { jobSpec job.Job spec pipeline.Spec runner pipeline.Runner - q pg.Q + ds sqlutil.DataSource orm ORM jobORM job.ORM pipelineORM pipeline.ORM @@ -93,7 +93,7 @@ func NewFluxMonitor( pipelineRunner pipeline.Runner, jobSpec job.Job, spec pipeline.Spec, - q pg.Q, + ds sqlutil.DataSource, orm ORM, jobORM job.ORM, pipelineORM pipeline.ORM, @@ -111,7 +111,7 @@ func NewFluxMonitor( chainID *big.Int, ) (*FluxMonitor, error) { fm := &FluxMonitor{ - q: q, + ds: ds, runner: pipelineRunner, jobSpec: jobSpec, spec: spec, @@ -159,10 +159,7 @@ func NewFromJobSpec( pipelineRunner pipeline.Runner, cfg Config, fcfg EvmFeeConfig, - ecfg EvmTransactionsConfig, - fmcfg FluxMonitorConfig, jcfg JobPipelineConfig, - dbCfg pg.QConfig, lggr logger.Logger, ) (*FluxMonitor, error) { fmSpec := jobSpec.FluxMonitorSpec @@ -253,7 +250,7 @@ func NewFromJobSpec( pipelineRunner, jobSpec, *jobSpec.PipelineSpec, - pg.NewQ(db, lggr, dbCfg), + db, orm, jobORM, pipelineORM, @@ -325,7 +322,7 @@ func (fm *FluxMonitor) Close() error { func (fm *FluxMonitor) JobID() int32 { return fm.spec.JobID } // HandleLog processes the contract logs -func (fm *FluxMonitor) HandleLog(broadcast log.Broadcast) { +func (fm *FluxMonitor) HandleLog(ctx context.Context, broadcast log.Broadcast) { log := broadcast.DecodedLog() if log == nil || reflect.ValueOf(log).IsNil() { fm.logger.Panic("HandleLog: failed to handle log of type nil") @@ -509,15 +506,16 @@ func (fm *FluxMonitor) SetOracleAddress() error { } func (fm *FluxMonitor) processLogs() { - for !fm.backlog.Empty() { + ctx, cancel := fm.chStop.NewCtx() + defer cancel() + + for ctx.Err() == nil && !fm.backlog.Empty() { broadcast := fm.backlog.Take() - fm.processBroadcast(broadcast) + fm.processBroadcast(ctx, broadcast) } } -func (fm *FluxMonitor) processBroadcast(broadcast log.Broadcast) { - ctx, cancel := fm.chStop.NewCtx() - defer cancel() +func (fm *FluxMonitor) processBroadcast(ctx context.Context, broadcast log.Broadcast) { // If the log is a duplicate of one we've seen before, ignore it (this // happens because of the LogBroadcaster's backfilling behavior). consumed, err := fm.logBroadcaster.WasAlreadyConsumed(ctx, broadcast) @@ -553,7 +551,7 @@ func (fm *FluxMonitor) processBroadcast(broadcast log.Broadcast) { } func (fm *FluxMonitor) markLogAsConsumed(ctx context.Context, broadcast log.Broadcast, decodedLog interface{}, started time.Time) { - if err := fm.logBroadcaster.MarkConsumed(ctx, broadcast); err != nil { + if err := fm.logBroadcaster.MarkConsumed(ctx, nil, broadcast); err != nil { fm.logger.Errorw("Failed to mark log as consumed", "err", err, "logType", fmt.Sprintf("%T", decodedLog), "log", broadcast.String(), "elapsed", time.Since(started)) } @@ -608,7 +606,7 @@ func (fm *FluxMonitor) respondToNewRoundLog(log flux_aggregator_wrapper.FluxAggr var markConsumed = true defer func() { if markConsumed { - if err := fm.logBroadcaster.MarkConsumed(ctx, lb); err != nil { + if err := fm.logBroadcaster.MarkConsumed(ctx, nil, lb); err != nil { fm.logger.Errorw("Failed to mark log consumed", "err", err, "log", lb.String()) } } @@ -665,13 +663,13 @@ func (fm *FluxMonitor) respondToNewRoundLog(log flux_aggregator_wrapper.FluxAggr // We always want to reset the idle timer upon receiving a NewRound log, so we do it before any `return` statements. fm.pollManager.ResetIdleTimer(log.StartedAt.Uint64()) - mostRecentRoundID, err := fm.orm.MostRecentFluxMonitorRoundID(fm.contractAddress) + mostRecentRoundID, err := fm.orm.MostRecentFluxMonitorRoundID(ctx, fm.contractAddress) if err != nil && !errors.Is(err, sql.ErrNoRows) { newRoundLogger.Errorf("error fetching Flux Monitor most recent round ID from DB: %v", err) return } - roundStats, jobRunStatus, err := fm.statsAndStatusForRound(logRoundID, 1) + roundStats, jobRunStatus, err := fm.statsAndStatusForRound(ctx, logRoundID, 1) if err != nil { newRoundLogger.Errorf("error determining round stats / run status for round: %v", err) return @@ -680,14 +678,14 @@ func (fm *FluxMonitor) respondToNewRoundLog(log flux_aggregator_wrapper.FluxAggr if logRoundID < mostRecentRoundID && roundStats.NumNewRoundLogs > 0 { newRoundLogger.Debugf("Received an older round log (and number of previously received NewRound logs is: %v) - "+ "a possible reorg, hence deleting round ids from %v to %v", roundStats.NumNewRoundLogs, logRoundID, mostRecentRoundID) - err = fm.orm.DeleteFluxMonitorRoundsBackThrough(fm.contractAddress, logRoundID) + err = fm.orm.DeleteFluxMonitorRoundsBackThrough(ctx, fm.contractAddress, logRoundID) if err != nil { newRoundLogger.Errorf("error deleting reorged Flux Monitor rounds from DB: %v", err) return } // as all newer stats were deleted, at this point a new round stats entry will be created - roundStats, err = fm.orm.FindOrCreateFluxMonitorRoundStats(fm.contractAddress, logRoundID, 1) + roundStats, err = fm.orm.FindOrCreateFluxMonitorRoundStats(ctx, fm.contractAddress, logRoundID, 1) if err != nil { newRoundLogger.Errorf("error determining subsequent round stats for round: %v", err) return @@ -771,7 +769,7 @@ func (fm *FluxMonitor) respondToNewRoundLog(log flux_aggregator_wrapper.FluxAggr return } - if !fm.isValidSubmission(newRoundLogger, answer, started) { + if !fm.isValidSubmission(ctx, newRoundLogger, answer, started) { return } @@ -779,14 +777,14 @@ func (fm *FluxMonitor) respondToNewRoundLog(log flux_aggregator_wrapper.FluxAggr newRoundLogger.Error("roundState.PaymentAmount shouldn't be nil") } - err = fm.q.Transaction(func(tx pg.Queryer) error { - if err2 := fm.runner.InsertFinishedRun(run, false, pg.WithQueryer(tx)); err2 != nil { + err = fm.Transact(ctx, func(tx sqlutil.DataSource) error { + if err2 := fm.runner.InsertFinishedRun(ctx, tx, run, false); err2 != nil { return err2 } if err2 := fm.queueTransactionForTxm(ctx, tx, run.ID, answer, roundState.RoundId, &log); err2 != nil { return err2 } - return fm.logBroadcaster.MarkConsumed(ctx, lb) + return fm.logBroadcaster.MarkConsumed(ctx, tx, lb) }) // Either the tx failed and we want to reprocess the log, or it succeeded and already marked it consumed markConsumed = false @@ -796,6 +794,10 @@ func (fm *FluxMonitor) respondToNewRoundLog(log flux_aggregator_wrapper.FluxAggr } } +func (fm *FluxMonitor) Transact(ctx context.Context, fn func(sqlutil.DataSource) error) error { + return sqlutil.TransactDataSource(ctx, fm.ds, nil, fn) +} + var ( // ErrNotEligible defines when the round is not eligible for submission ErrNotEligible = errors.New("not eligible to submit") @@ -832,7 +834,7 @@ func (fm *FluxMonitor) pollIfEligible(pollReq PollRequestType, deviationChecker var markConsumed = true defer func() { if markConsumed && broadcast != nil { - if err := fm.logBroadcaster.MarkConsumed(ctx, broadcast); err != nil { + if err := fm.logBroadcaster.MarkConsumed(ctx, nil, broadcast); err != nil { l.Errorw("Failed to mark log consumed", "err", err, "log", broadcast.String()) } } @@ -863,8 +865,7 @@ func (fm *FluxMonitor) pollIfEligible(pollReq PollRequestType, deviationChecker roundState, err := fm.roundState(0) if err != nil { l.Errorw("unable to determine eligibility to submit from FluxAggregator contract", "err", err) - fm.jobORM.TryRecordError( - fm.spec.JobID, + fm.jobORM.TryRecordError(fm.spec.JobID, "Unable to call roundState method on provided contract. Check contract address.", ) @@ -884,8 +885,7 @@ func (fm *FluxMonitor) pollIfEligible(pollReq PollRequestType, deviationChecker roundStateNew, err2 := fm.roundState(roundState.RoundId) if err2 != nil { l.Errorw("unable to determine eligibility to submit from FluxAggregator contract", "err", err2) - fm.jobORM.TryRecordError( - fm.spec.JobID, + fm.jobORM.TryRecordError(fm.spec.JobID, "Unable to call roundState method on provided contract. Check contract address.", ) @@ -909,7 +909,7 @@ func (fm *FluxMonitor) pollIfEligible(pollReq PollRequestType, deviationChecker } }() - roundStats, jobRunStatus, err := fm.statsAndStatusForRound(roundState.RoundId, 0) + roundStats, jobRunStatus, err := fm.statsAndStatusForRound(ctx, roundState.RoundId, 0) if err != nil { l.Errorw("error determining round stats / run status for round", "err", err) @@ -977,7 +977,7 @@ func (fm *FluxMonitor) pollIfEligible(pollReq PollRequestType, deviationChecker return } - if !fm.isValidSubmission(l, answer, started) { + if !fm.isValidSubmission(ctx, l, answer, started) { return } @@ -1005,8 +1005,8 @@ func (fm *FluxMonitor) pollIfEligible(pollReq PollRequestType, deviationChecker l.Error("roundState.PaymentAmount shouldn't be nil") } - err = fm.q.Transaction(func(tx pg.Queryer) error { - if err2 := fm.runner.InsertFinishedRun(run, true, pg.WithQueryer(tx)); err2 != nil { + err = fm.Transact(ctx, func(tx sqlutil.DataSource) error { + if err2 := fm.runner.InsertFinishedRun(ctx, tx, run, true); err2 != nil { return err2 } if err2 := fm.queueTransactionForTxm(ctx, tx, run.ID, answer, roundState.RoundId, nil); err2 != nil { @@ -1014,7 +1014,7 @@ func (fm *FluxMonitor) pollIfEligible(pollReq PollRequestType, deviationChecker } if broadcast != nil { // In the case of a flag lowered, the pollEligible call is triggered by a log. - return fm.logBroadcaster.MarkConsumed(ctx, broadcast) + return fm.logBroadcaster.MarkConsumed(ctx, tx, broadcast) } return nil }) @@ -1031,7 +1031,7 @@ func (fm *FluxMonitor) pollIfEligible(pollReq PollRequestType, deviationChecker // If the answer is outside the allowable range, log an error and don't submit. // to avoid an onchain reversion. -func (fm *FluxMonitor) isValidSubmission(l logger.Logger, answer decimal.Decimal, started time.Time) bool { +func (fm *FluxMonitor) isValidSubmission(ctx context.Context, l logger.Logger, answer decimal.Decimal, started time.Time) bool { if fm.submissionChecker.IsValid(answer) { return true } @@ -1085,7 +1085,7 @@ func (fm *FluxMonitor) initialRoundState() flux_aggregator_wrapper.OracleRoundSt return latestRoundState } -func (fm *FluxMonitor) queueTransactionForTxm(ctx context.Context, tx pg.Queryer, runID int64, answer decimal.Decimal, roundID uint32, log *flux_aggregator_wrapper.FluxAggregatorNewRound) error { +func (fm *FluxMonitor) queueTransactionForTxm(ctx context.Context, tx sqlutil.DataSource, runID int64, answer decimal.Decimal, roundID uint32, log *flux_aggregator_wrapper.FluxAggregatorNewRound) error { // Use pipeline run ID to generate globally unique key that can correlate this run to a Tx idempotencyKey := fmt.Sprintf("fluxmonitor-%d", runID) // Submit the Eth Tx @@ -1105,12 +1105,12 @@ func (fm *FluxMonitor) queueTransactionForTxm(ctx context.Context, tx pg.Queryer numLogs = 1 } // Update the flux monitor round stats - err = fm.orm.UpdateFluxMonitorRoundStats( + err = fm.orm.WithDataSource(tx).UpdateFluxMonitorRoundStats( + ctx, fm.contractAddress, roundID, runID, numLogs, - pg.WithQueryer(tx), ) if err != nil { fm.logger.Errorw( @@ -1124,8 +1124,8 @@ func (fm *FluxMonitor) queueTransactionForTxm(ctx context.Context, tx pg.Queryer return nil } -func (fm *FluxMonitor) statsAndStatusForRound(roundID uint32, newRoundLogs uint) (FluxMonitorRoundStatsV2, pipeline.RunStatus, error) { - roundStats, err := fm.orm.FindOrCreateFluxMonitorRoundStats(fm.contractAddress, roundID, newRoundLogs) +func (fm *FluxMonitor) statsAndStatusForRound(ctx context.Context, roundID uint32, newRoundLogs uint) (FluxMonitorRoundStatsV2, pipeline.RunStatus, error) { + roundStats, err := fm.orm.FindOrCreateFluxMonitorRoundStats(ctx, fm.contractAddress, roundID, newRoundLogs) if err != nil { return FluxMonitorRoundStatsV2{}, pipeline.RunStatusUnknown, err } @@ -1133,7 +1133,7 @@ func (fm *FluxMonitor) statsAndStatusForRound(roundID uint32, newRoundLogs uint) // JobRun will not exist if this is the first time responding to this round var run pipeline.Run if roundStats.PipelineRunID.Valid { - run, err = fm.pipelineORM.FindRun(roundStats.PipelineRunID.Int64) + run, err = fm.pipelineORM.FindRun(ctx, roundStats.PipelineRunID.Int64) if err != nil { return FluxMonitorRoundStatsV2{}, pipeline.RunStatusUnknown, err } diff --git a/core/services/fluxmonitorv2/flux_monitor_test.go b/core/services/fluxmonitorv2/flux_monitor_test.go index e4db716bbbb..042ddb99afb 100644 --- a/core/services/fluxmonitorv2/flux_monitor_test.go +++ b/core/services/fluxmonitorv2/flux_monitor_test.go @@ -22,6 +22,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/assets" "github.com/smartcontractkit/chainlink-common/pkg/services/servicetest" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" txmgrcommon "github.com/smartcontractkit/chainlink/v2/common/txmgr" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/log" logmocks "github.com/smartcontractkit/chainlink/v2/core/chains/evm/log/mocks" @@ -31,7 +32,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/internal/cltest/heavyweight" "github.com/smartcontractkit/chainlink/v2/core/internal/mocks" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" - "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/configtest" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/pgtest" "github.com/smartcontractkit/chainlink/v2/core/logger" corenull "github.com/smartcontractkit/chainlink/v2/core/null" @@ -40,7 +40,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/job" jobmocks "github.com/smartcontractkit/chainlink/v2/core/services/job/mocks" "github.com/smartcontractkit/chainlink/v2/core/services/keystore/keys/ethkey" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" pipelinemocks "github.com/smartcontractkit/chainlink/v2/core/services/pipeline/mocks" ) @@ -53,8 +52,8 @@ var ( type answerSet struct{ latestAnswer, polledAnswer int64 } -func newORM(t *testing.T, db *sqlx.DB, cfg pg.QConfig, txm txmgr.TxManager) fluxmonitorv2.ORM { - return fluxmonitorv2.NewORM(db, logger.TestLogger(t), cfg, txm, txmgrcommon.NewSendEveryStrategy(), txmgr.TransmitCheckerSpec{}) +func newORM(t *testing.T, db *sqlx.DB, txm txmgr.TxManager) fluxmonitorv2.ORM { + return fluxmonitorv2.NewORM(db, logger.TestLogger(t), txm, txmgrcommon.NewSendEveryStrategy(), txmgr.TransmitCheckerSpec{}) } var ( @@ -149,7 +148,7 @@ type setupOptions struct { // setup sets up a Flux Monitor for testing, allowing the test to provide // functional options to configure the setup -func setup(t *testing.T, db *sqlx.DB, optionFns ...func(*setupOptions)) (*fluxmonitorv2.FluxMonitor, *testMocks) { +func setup(t *testing.T, ds sqlutil.DataSource, optionFns ...func(*setupOptions)) (*fluxmonitorv2.FluxMonitor, *testMocks) { t.Helper() testutils.SkipShort(t, "long test") @@ -190,7 +189,7 @@ func setup(t *testing.T, db *sqlx.DB, optionFns ...func(*setupOptions)) (*fluxmo tm.pipelineRunner, job.Job{}, pipelineSpec, - pg.NewQ(db, lggr, pgtest.NewQConfig(true)), + ds, options.orm, tm.jobORM, tm.pipelineORM, @@ -386,7 +385,7 @@ func TestFluxMonitor_PollIfEligible(t *testing.T) { } tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(reportableRoundID), mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(reportableRoundID), mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ Aggregator: contractAddress, RoundID: reportableRoundID, @@ -395,12 +394,12 @@ func TestFluxMonitor_PollIfEligible(t *testing.T) { }, nil) tm.pipelineORM. - On("FindRun", run.ID). + On("FindRun", mock.Anything, run.ID). Return(run, nil) } else { if tc.connected { tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(reportableRoundID), mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(reportableRoundID), mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ Aggregator: contractAddress, RoundID: reportableRoundID, @@ -469,7 +468,7 @@ func TestFluxMonitor_PollIfEligible(t *testing.T) { tm.pipelineRunner.On("InsertFinishedRun", mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(nil). Run(func(args mock.Arguments) { - args.Get(0).(*pipeline.Run).ID = 1 + args.Get(2).(*pipeline.Run).ID = 1 }). Once() tm.contractSubmitter. @@ -479,13 +478,14 @@ func TestFluxMonitor_PollIfEligible(t *testing.T) { tm.orm. On("UpdateFluxMonitorRoundStats", + mock.Anything, contractAddress, uint32(reportableRoundID), int64(1), mock.Anything, - mock.Anything, ). Return(nil) + tm.orm.On("WithDataSource", mock.Anything).Return(fluxmonitorv2.ORM(tm.orm)) } oracles := []common.Address{nodeAddr, testutils.NewAddress()} @@ -560,6 +560,7 @@ func TestPollingDeviationChecker_BuffersLogs(t *testing.T) { logsAwaiter := cltest.NewAwaiter() tm.keyStore.On("EnabledKeysForChain", mock.Anything, testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil).Once() + tm.orm.On("WithDataSource", mock.Anything).Return(fluxmonitorv2.ORM(tm.orm)) tm.fluxAggregator.On("Address").Return(common.Address{}) tm.fluxAggregator.On("LatestRoundData", nilOpts).Return(freshContractRoundDataResponse()).Maybe() @@ -573,19 +574,18 @@ func TestPollingDeviationChecker_BuffersLogs(t *testing.T) { tm.fluxAggregator.On("OracleRoundState", nilOpts, nodeAddr, uint32(3)).Return(makeRoundStateForRoundID(3), nil).Once() tm.fluxAggregator.On("OracleRoundState", nilOpts, nodeAddr, uint32(4)).Return(makeRoundStateForRoundID(4), nil).Once() tm.fluxAggregator.On("GetOracles", nilOpts).Return(oracles, nil) - // tm.fluxAggregator.On("Address").Return(contractAddress, nil) tm.logBroadcaster.On("Register", fm, mock.Anything).Return(func() {}) tm.logBroadcaster.On("IsConnected").Return(true).Maybe() - tm.orm.On("MostRecentFluxMonitorRoundID", contractAddress).Return(uint32(1), nil) - tm.orm.On("MostRecentFluxMonitorRoundID", contractAddress).Return(uint32(3), nil) - tm.orm.On("MostRecentFluxMonitorRoundID", contractAddress).Return(uint32(4), nil) + tm.orm.On("MostRecentFluxMonitorRoundID", mock.Anything, contractAddress).Return(uint32(1), nil) + tm.orm.On("MostRecentFluxMonitorRoundID", mock.Anything, contractAddress).Return(uint32(3), nil) + tm.orm.On("MostRecentFluxMonitorRoundID", mock.Anything, contractAddress).Return(uint32(4), nil) // Round 1 run := &pipeline.Run{ID: 1} tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(1), mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(1), mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ Aggregator: contractAddress, RoundID: 1, @@ -605,7 +605,7 @@ func TestPollingDeviationChecker_BuffersLogs(t *testing.T) { On("InsertFinishedRun", mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(nil). Run(func(args mock.Arguments) { - args.Get(0).(*pipeline.Run).ID = 1 + args.Get(2).(*pipeline.Run).ID = 1 }).Once() tm.contractSubmitter. On("Submit", mock.Anything, big.NewInt(1), big.NewInt(fetchedValue), buildIdempotencyKey(run.ID)). @@ -614,18 +614,18 @@ func TestPollingDeviationChecker_BuffersLogs(t *testing.T) { tm.orm. On("UpdateFluxMonitorRoundStats", + mock.Anything, contractAddress, uint32(1), mock.AnythingOfType("int64"), //int64(1), mock.Anything, - mock.Anything, ). Return(nil).Once() // Round 3 run = &pipeline.Run{ID: 2} tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(3), mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(3), mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ Aggregator: contractAddress, RoundID: 3, @@ -645,7 +645,7 @@ func TestPollingDeviationChecker_BuffersLogs(t *testing.T) { On("InsertFinishedRun", mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(nil). Run(func(args mock.Arguments) { - args.Get(0).(*pipeline.Run).ID = 2 + args.Get(2).(*pipeline.Run).ID = 2 }).Once() tm.contractSubmitter. On("Submit", mock.Anything, big.NewInt(3), big.NewInt(fetchedValue), buildIdempotencyKey(run.ID)). @@ -653,18 +653,18 @@ func TestPollingDeviationChecker_BuffersLogs(t *testing.T) { Once() tm.orm. On("UpdateFluxMonitorRoundStats", + mock.Anything, contractAddress, uint32(3), mock.AnythingOfType("int64"), //int64(2), mock.Anything, - mock.Anything, ). Return(nil).Once() // Round 4 run = &pipeline.Run{ID: 3} tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(4), mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(4), mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ Aggregator: contractAddress, RoundID: 3, @@ -684,7 +684,7 @@ func TestPollingDeviationChecker_BuffersLogs(t *testing.T) { On("InsertFinishedRun", mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(nil). Run(func(args mock.Arguments) { - args.Get(0).(*pipeline.Run).ID = 3 + args.Get(2).(*pipeline.Run).ID = 3 }).Once() tm.contractSubmitter. On("Submit", mock.Anything, big.NewInt(4), big.NewInt(fetchedValue), buildIdempotencyKey(run.ID)). @@ -692,11 +692,11 @@ func TestPollingDeviationChecker_BuffersLogs(t *testing.T) { Once() tm.orm. On("UpdateFluxMonitorRoundStats", + mock.Anything, contractAddress, uint32(4), mock.AnythingOfType("int64"), //int64(3), mock.Anything, - mock.Anything, ). Return(nil). Once(). @@ -711,17 +711,17 @@ func TestPollingDeviationChecker_BuffersLogs(t *testing.T) { logBroadcast.On("DecodedLog").Return(&flux_aggregator_wrapper.FluxAggregatorNewRound{RoundId: big.NewInt(int64(i)), StartedAt: big.NewInt(0)}) logBroadcast.On("String").Maybe().Return("") tm.logBroadcaster.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) - tm.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + tm.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) logBroadcasts = append(logBroadcasts, logBroadcast) } - - fm.HandleLog(logBroadcasts[0]) // Get the checker to start processing a log so we can freeze it + ctx := testutils.Context(t) + fm.HandleLog(ctx, logBroadcasts[0]) // Get the checker to start processing a log so we can freeze it readyToFillQueue.AwaitOrFail(t) - fm.HandleLog(logBroadcasts[1]) // This log is evicted from the priority queue - fm.HandleLog(logBroadcasts[2]) - fm.HandleLog(logBroadcasts[3]) + fm.HandleLog(ctx, logBroadcasts[1]) // This log is evicted from the priority queue + fm.HandleLog(ctx, logBroadcasts[2]) + fm.HandleLog(ctx, logBroadcasts[3]) logsAwaiter.ItHappened() readyToAssert.AwaitOrFail(t) @@ -749,7 +749,7 @@ func TestFluxMonitor_TriggerIdleTimeThreshold(t *testing.T) { t.Parallel() var ( - orm = newORM(t, db, pgtest.NewQConfig(true), nil) + orm = newORM(t, db, nil) ) fm, tm := setup(t, db, disablePollTicker(true), disableIdleTimer(tc.idleTimerDisabled), setIdleTimerPeriod(tc.idleDuration), withORM(orm)) @@ -795,8 +795,8 @@ func TestFluxMonitor_TriggerIdleTimeThreshold(t *testing.T) { tm.logBroadcast.On("DecodedLog").Return(&decodedLog) tm.logBroadcast.On("String").Maybe().Return("") tm.logBroadcaster.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) - tm.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) - fm.HandleLog(tm.logBroadcast) + tm.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) + fm.HandleLog(testutils.Context(t), tm.logBroadcast) g.Eventually(chBlock).Should(gomega.BeClosed()) @@ -856,7 +856,7 @@ func TestFluxMonitor_HibernationTickerFiresMultipleTimes(t *testing.T) { pollOccured <- struct{}{} }) tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(1), mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(1), mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ Aggregator: contractAddress, RoundID: 1, @@ -873,7 +873,7 @@ func TestFluxMonitor_HibernationTickerFiresMultipleTimes(t *testing.T) { // Finds an existing run created by the initial poll tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(1), mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(1), mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ PipelineRunID: corenull.NewInt64(int64(1), true), Aggregator: contractAddress, @@ -881,7 +881,7 @@ func TestFluxMonitor_HibernationTickerFiresMultipleTimes(t *testing.T) { NumSubmissions: 1, }, nil).Once() finishedAt := time.Now() - tm.pipelineORM.On("FindRun", int64(1)).Return(pipeline.Run{ + tm.pipelineORM.On("FindRun", mock.Anything, int64(1)).Return(pipeline.Run{ FinishedAt: null.TimeFrom(finishedAt), }, nil) @@ -893,7 +893,7 @@ func TestFluxMonitor_HibernationTickerFiresMultipleTimes(t *testing.T) { pollOccured <- struct{}{} }) tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(2), mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(2), mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ Aggregator: contractAddress, RoundID: 2, @@ -950,7 +950,7 @@ func TestFluxMonitor_HibernationIsEnteredAndRetryTickerStopped(t *testing.T) { pollOccured <- struct{}{} }) tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, roundOne, mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, roundOne, mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ Aggregator: contractAddress, RoundID: 1, @@ -970,7 +970,7 @@ func TestFluxMonitor_HibernationIsEnteredAndRetryTickerStopped(t *testing.T) { // Finds an error run, so that retry ticker will be kicked off tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, roundOne, mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, roundOne, mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ PipelineRunID: corenull.NewInt64(int64(1), true), Aggregator: contractAddress, @@ -978,7 +978,7 @@ func TestFluxMonitor_HibernationIsEnteredAndRetryTickerStopped(t *testing.T) { NumSubmissions: 1, }, nil).Once() finishedAt := time.Now() - tm.pipelineORM.On("FindRun", int64(1)).Return(pipeline.Run{ + tm.pipelineORM.On("FindRun", mock.Anything, int64(1)).Return(pipeline.Run{ FinishedAt: null.TimeFrom(finishedAt), FatalErrors: []null.String{null.StringFrom("an error to start retry ticker")}, }, nil) @@ -997,7 +997,7 @@ func TestFluxMonitor_HibernationIsEnteredAndRetryTickerStopped(t *testing.T) { roundState2 := flux_aggregator_wrapper.OracleRoundState{RoundId: 2, EligibleToSubmit: false, LatestSubmission: answerBigInt, StartedAt: 0} tm.fluxAggregator.On("OracleRoundState", nilOpts, nodeAddr, roundZero).Return(roundState2, nil).Once() tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, roundTwo, mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, roundTwo, mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ Aggregator: contractAddress, RoundID: 2, @@ -1054,7 +1054,7 @@ func TestFluxMonitor_IdleTimerResetsOnNewRound(t *testing.T) { roundState1 := flux_aggregator_wrapper.OracleRoundState{RoundId: 1, EligibleToSubmit: false, LatestSubmission: answerBigInt, StartedAt: now()} tm.fluxAggregator.On("OracleRoundState", nilOpts, nodeAddr, uint32(0)).Return(roundState1, nil).Once() tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(1), mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(1), mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ Aggregator: contractAddress, RoundID: 1, @@ -1072,7 +1072,7 @@ func TestFluxMonitor_IdleTimerResetsOnNewRound(t *testing.T) { }) // Finds an existing run created by the initial poll tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(1), mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(1), mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ PipelineRunID: corenull.NewInt64(int64(1), true), Aggregator: contractAddress, @@ -1080,7 +1080,7 @@ func TestFluxMonitor_IdleTimerResetsOnNewRound(t *testing.T) { NumSubmissions: 1, }, nil).Once() finishedAt := time.Now() - tm.pipelineORM.On("FindRun", int64(1)).Return(pipeline.Run{ + tm.pipelineORM.On("FindRun", mock.Anything, int64(1)).Return(pipeline.Run{ FinishedAt: null.TimeFrom(finishedAt), }, nil) @@ -1092,7 +1092,7 @@ func TestFluxMonitor_IdleTimerResetsOnNewRound(t *testing.T) { idleDurationOccured <- struct{}{} }) tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(2), mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(2), mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ Aggregator: contractAddress, RoundID: 2, @@ -1107,7 +1107,7 @@ func TestFluxMonitor_IdleTimerResetsOnNewRound(t *testing.T) { idleDurationOccured <- struct{}{} }) tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(3), mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(3), mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ Aggregator: contractAddress, RoundID: 3, @@ -1118,7 +1118,7 @@ func TestFluxMonitor_IdleTimerResetsOnNewRound(t *testing.T) { tm.logBroadcaster.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil).Once() tm.logBroadcast.On("DecodedLog").Return(&flux_aggregator_wrapper.FluxAggregatorAnswerUpdated{}) tm.logBroadcast.On("String").Maybe().Return("") - tm.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil).Once() + tm.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() fm.ExportedBacklog().Add(fluxmonitorv2.PriorityNewRoundLog, tm.logBroadcast) fm.ExportedProcessLogs() @@ -1133,7 +1133,7 @@ func TestFluxMonitor_RoundTimeoutCausesPoll_timesOutAtZero(t *testing.T) { var ( oracles = []common.Address{nodeAddr, testutils.NewAddress()} - orm = newORM(t, db, pgtest.NewQConfig(true), nil) + orm = newORM(t, db, nil) ) fm, tm := setup(t, db, disablePollTicker(true), disableIdleTimer(true), withORM(orm)) @@ -1193,10 +1193,7 @@ func TestFluxMonitor_UsesPreviousRoundStateOnStartup_RoundTimeout(t *testing.T) t.Run(test.name, func(t *testing.T) { t.Parallel() - cfg := configtest.NewTestGeneralConfig(t) - var ( - orm = newORM(t, db, cfg.Database(), nil) - ) + orm := newORM(t, db, nil) fm, tm := setup(t, db, disablePollTicker(true), disableIdleTimer(true), withORM(orm)) @@ -1260,11 +1257,7 @@ func TestFluxMonitor_UsesPreviousRoundStateOnStartup_IdleTimer(t *testing.T) { for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { - cfg := configtest.NewTestGeneralConfig(t) - - var ( - orm = newORM(t, db, cfg.Database(), nil) - ) + orm := newORM(t, db, nil) fm, tm := setup(t, db, @@ -1323,11 +1316,7 @@ func TestFluxMonitor_RoundTimeoutCausesPoll_timesOutNotZero(t *testing.T) { g := gomega.NewWithT(t) db, nodeAddr := setupStoreWithKey(t) oracles := []common.Address{nodeAddr, testutils.NewAddress()} - cfg := configtest.NewTestGeneralConfig(t) - - var ( - orm = newORM(t, db, cfg.Database(), nil) - ) + orm := newORM(t, db, nil) fm, tm := setup(t, db, disablePollTicker(true), disableIdleTimer(true), withORM(orm)) @@ -1381,14 +1370,14 @@ func TestFluxMonitor_RoundTimeoutCausesPoll_timesOutNotZero(t *testing.T) { servicetest.Run(t, fm) tm.logBroadcaster.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) - tm.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + tm.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) tm.logBroadcast.On("DecodedLog").Return(&flux_aggregator_wrapper.FluxAggregatorNewRound{ RoundId: big.NewInt(0), StartedAt: big.NewInt(time.Now().UTC().Unix()), }) tm.logBroadcast.On("String").Maybe().Return("") // To mark it consumed, we need to be eligible to submit. - fm.HandleLog(tm.logBroadcast) + fm.HandleLog(testutils.Context(t), tm.logBroadcast) g.Eventually(chRoundState1).Should(gomega.BeClosed()) g.Eventually(chRoundState2).Should(gomega.BeClosed()) @@ -1409,7 +1398,7 @@ func TestFluxMonitor_ConsumeLogBroadcast(t *testing.T) { tm.logBroadcaster.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil).Once() tm.logBroadcast.On("DecodedLog").Return(&flux_aggregator_wrapper.FluxAggregatorAnswerUpdated{}) tm.logBroadcast.On("String").Maybe().Return("") - tm.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil).Once() + tm.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() fm.ExportedBacklog().Add(fluxmonitorv2.PriorityNewRoundLog, tm.logBroadcast) fm.ExportedProcessLogs() @@ -1468,11 +1457,12 @@ func TestFluxMonitor_DoesNotDoubleSubmit(t *testing.T) { tm.keyStore.On("EnabledKeysForChain", mock.Anything, testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil).Once() tm.logBroadcaster.On("IsConnected").Return(true).Maybe() + tm.orm.On("WithDataSource", mock.Anything).Return(fluxmonitorv2.ORM(tm.orm)) // Mocks initiated by the New Round log - tm.orm.On("MostRecentFluxMonitorRoundID", contractAddress).Return(uint32(roundID), nil).Once() + tm.orm.On("MostRecentFluxMonitorRoundID", mock.Anything, contractAddress).Return(uint32(roundID), nil).Once() tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(roundID), mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(roundID), mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ Aggregator: contractAddress, RoundID: roundID, @@ -1492,17 +1482,17 @@ func TestFluxMonitor_DoesNotDoubleSubmit(t *testing.T) { On("InsertFinishedRun", mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(nil). Run(func(args mock.Arguments) { - args.Get(0).(*pipeline.Run).ID = 1 + args.Get(2).(*pipeline.Run).ID = 1 }) - tm.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil).Once() + tm.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() tm.contractSubmitter.On("Submit", mock.Anything, big.NewInt(roundID), big.NewInt(answer), buildIdempotencyKey(run.ID)).Return(nil).Once() tm.orm. On("UpdateFluxMonitorRoundStats", + mock.Anything, contractAddress, uint32(roundID), int64(1), uint(1), - mock.Anything, ). Return(nil) @@ -1545,7 +1535,7 @@ func TestFluxMonitor_DoesNotDoubleSubmit(t *testing.T) { Once() tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(roundID), mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(roundID), mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ PipelineRunID: corenull.NewInt64(int64(1), true), Aggregator: contractAddress, @@ -1554,7 +1544,7 @@ func TestFluxMonitor_DoesNotDoubleSubmit(t *testing.T) { }, nil).Once() now := time.Now() - tm.pipelineORM.On("FindRun", int64(1)).Return(pipeline.Run{ + tm.pipelineORM.On("FindRun", mock.Anything, int64(1)).Return(pipeline.Run{ FinishedAt: null.TimeFrom(now), }, nil) @@ -1583,6 +1573,7 @@ func TestFluxMonitor_DoesNotDoubleSubmit(t *testing.T) { run := &pipeline.Run{ID: 1} tm.keyStore.On("EnabledKeysForChain", mock.Anything, testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil).Once() tm.logBroadcaster.On("IsConnected").Return(true).Maybe() + tm.orm.On("WithDataSource", mock.Anything).Return(fluxmonitorv2.ORM(tm.orm)) // First, force the node to try to poll, which should result in a submission tm.fluxAggregator.On("LatestRoundData", nilOpts).Return(flux_aggregator_wrapper.LatestRoundData{ @@ -1600,7 +1591,7 @@ func TestFluxMonitor_DoesNotDoubleSubmit(t *testing.T) { }, nil). Once() tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(roundID), mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(roundID), mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ Aggregator: contractAddress, RoundID: roundID, @@ -1620,16 +1611,16 @@ func TestFluxMonitor_DoesNotDoubleSubmit(t *testing.T) { On("InsertFinishedRun", mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(nil). Run(func(args mock.Arguments) { - args.Get(0).(*pipeline.Run).ID = 1 + args.Get(2).(*pipeline.Run).ID = 1 }) tm.contractSubmitter.On("Submit", mock.Anything, big.NewInt(roundID), big.NewInt(answer), buildIdempotencyKey(run.ID)).Return(nil).Once() tm.orm. On("UpdateFluxMonitorRoundStats", + mock.Anything, contractAddress, uint32(roundID), int64(1), uint(0), - mock.Anything, ). Return(nil). Once() @@ -1639,18 +1630,18 @@ func TestFluxMonitor_DoesNotDoubleSubmit(t *testing.T) { fm.ExportedPollIfEligible(0, 0) // Now fire off the NewRound log and ensure it does not respond this time - tm.orm.On("MostRecentFluxMonitorRoundID", contractAddress).Return(uint32(roundID), nil) + tm.orm.On("MostRecentFluxMonitorRoundID", mock.Anything, contractAddress).Return(uint32(roundID), nil) tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(roundID), mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(roundID), mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ PipelineRunID: corenull.NewInt64(int64(1), true), Aggregator: contractAddress, RoundID: roundID, NumSubmissions: 1, }, nil).Once() - tm.pipelineORM.On("FindRun", int64(1)).Return(pipeline.Run{}, nil) + tm.pipelineORM.On("FindRun", mock.Anything, int64(1)).Return(pipeline.Run{}, nil) - tm.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + tm.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) fm.ExportedRespondToNewRoundLog(&flux_aggregator_wrapper.FluxAggregatorNewRound{ RoundId: big.NewInt(roundID), StartedAt: big.NewInt(0), @@ -1679,6 +1670,7 @@ func TestFluxMonitor_DoesNotDoubleSubmit(t *testing.T) { run := &pipeline.Run{ID: 1} tm.keyStore.On("EnabledKeysForChain", mock.Anything, testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil).Once() tm.logBroadcaster.On("IsConnected").Return(true).Maybe() + tm.orm.On("WithDataSource", mock.Anything).Return(fluxmonitorv2.ORM(tm.orm)) // First, force the node to try to poll, which should result in a submission tm.fluxAggregator.On("LatestRoundData", nilOpts).Return(flux_aggregator_wrapper.LatestRoundData{ @@ -1696,7 +1688,7 @@ func TestFluxMonitor_DoesNotDoubleSubmit(t *testing.T) { }, nil). Once() tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(roundID), mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(roundID), mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ Aggregator: contractAddress, RoundID: roundID, @@ -1716,16 +1708,16 @@ func TestFluxMonitor_DoesNotDoubleSubmit(t *testing.T) { On("InsertFinishedRun", mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(nil). Run(func(args mock.Arguments) { - args.Get(0).(*pipeline.Run).ID = 1 + args.Get(2).(*pipeline.Run).ID = 1 }) tm.contractSubmitter.On("Submit", mock.Anything, big.NewInt(roundID), big.NewInt(answer), buildIdempotencyKey(run.ID)).Return(nil).Once() tm.orm. On("UpdateFluxMonitorRoundStats", + mock.Anything, contractAddress, uint32(roundID), int64(1), uint(0), - mock.Anything, ). Return(nil). Once() @@ -1735,27 +1727,27 @@ func TestFluxMonitor_DoesNotDoubleSubmit(t *testing.T) { fm.ExportedPollIfEligible(0, 0) // Now fire off the NewRound log and ensure it does not respond this time - tm.orm.On("MostRecentFluxMonitorRoundID", contractAddress).Return(uint32(roundID), nil) + tm.orm.On("MostRecentFluxMonitorRoundID", mock.Anything, contractAddress).Return(uint32(roundID), nil) tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(olderRoundID), mock.Anything). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(olderRoundID), mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ PipelineRunID: corenull.NewInt64(int64(1), true), Aggregator: contractAddress, RoundID: olderRoundID, NumSubmissions: 1, }, nil).Once() - tm.pipelineORM.On("FindRun", int64(1)).Return(pipeline.Run{}, nil) + tm.pipelineORM.On("FindRun", mock.Anything, int64(1)).Return(pipeline.Run{}, nil) - tm.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + tm.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) fm.ExportedRespondToNewRoundLog(&flux_aggregator_wrapper.FluxAggregatorNewRound{ RoundId: big.NewInt(olderRoundID), StartedAt: big.NewInt(0), }, log.NewLogBroadcast(types.Log{}, cltest.FixtureChainID, nil)) // Simulate a reorg - fire the same NewRound log again, which should result in a submission this time - tm.orm.On("MostRecentFluxMonitorRoundID", contractAddress).Return(uint32(roundID), nil) + tm.orm.On("MostRecentFluxMonitorRoundID", mock.Anything, contractAddress).Return(uint32(roundID), nil) tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(olderRoundID), uint(1)). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(olderRoundID), uint(1)). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ PipelineRunID: corenull.NewInt64(int64(1), true), Aggregator: contractAddress, @@ -1763,14 +1755,14 @@ func TestFluxMonitor_DoesNotDoubleSubmit(t *testing.T) { NumSubmissions: 1, NumNewRoundLogs: 1, }, nil).Once() - tm.pipelineORM.On("FindRun", int64(1)).Return(pipeline.Run{}, nil) + tm.pipelineORM.On("FindRun", mock.Anything, int64(1)).Return(pipeline.Run{}, nil) // all newer round stats should be deleted - tm.orm.On("DeleteFluxMonitorRoundsBackThrough", contractAddress, uint32(olderRoundID)).Return(nil) + tm.orm.On("DeleteFluxMonitorRoundsBackThrough", mock.Anything, contractAddress, uint32(olderRoundID)).Return(nil) // then we are returning a fresh round stat, with NumSubmissions: 0 tm.orm. - On("FindOrCreateFluxMonitorRoundStats", contractAddress, uint32(olderRoundID), uint(1)). + On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, uint32(olderRoundID), uint(1)). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{ PipelineRunID: corenull.NewInt64(int64(1), true), Aggregator: contractAddress, @@ -1795,16 +1787,16 @@ func TestFluxMonitor_DoesNotDoubleSubmit(t *testing.T) { tm.orm. On("UpdateFluxMonitorRoundStats", + mock.Anything, contractAddress, uint32(olderRoundID), int64(1), uint(1), - mock.Anything, ). Return(nil). Once() - tm.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + tm.logBroadcaster.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) fm.ExportedRespondToNewRoundLog(&flux_aggregator_wrapper.FluxAggregatorNewRound{ RoundId: big.NewInt(olderRoundID), StartedAt: big.NewInt(0), @@ -1824,6 +1816,7 @@ func TestFluxMonitor_DrumbeatTicker(t *testing.T) { fm, tm := setup(t, db, disablePollTicker(true), disableIdleTimer(true), enableDrumbeatTicker("@every 3s", 2*time.Second)) tm.keyStore.On("EnabledKeysForChain", mock.Anything, testutils.FixtureChainID).Return([]ethkey.KeyV2{{Address: nodeAddr}}, nil) + tm.orm.On("WithDataSource", mock.Anything).Return(fluxmonitorv2.ORM(tm.orm)) const fetchedAnswer = 100 answerBigInt := big.NewInt(fetchedAnswer) @@ -1853,7 +1846,7 @@ func TestFluxMonitor_DrumbeatTicker(t *testing.T) { Return(roundState, nil). Once() - tm.orm.On("FindOrCreateFluxMonitorRoundStats", contractAddress, roundID, mock.Anything). + tm.orm.On("FindOrCreateFluxMonitorRoundStats", mock.Anything, contractAddress, roundID, mock.Anything). Return(fluxmonitorv2.FluxMonitorRoundStatsV2{Aggregator: contractAddress, RoundID: roundID}, nil). Once() @@ -1895,7 +1888,7 @@ func TestFluxMonitor_DrumbeatTicker(t *testing.T) { tm.pipelineRunner.On("InsertFinishedRun", mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(nil). Run(func(args mock.Arguments) { - args.Get(0).(*pipeline.Run).ID = runID + args.Get(2).(*pipeline.Run).ID = runID }). Once() tm.contractSubmitter. @@ -1904,7 +1897,7 @@ func TestFluxMonitor_DrumbeatTicker(t *testing.T) { Once() tm.orm. - On("UpdateFluxMonitorRoundStats", contractAddress, roundID, runID, mock.Anything, mock.Anything). + On("UpdateFluxMonitorRoundStats", mock.Anything, contractAddress, roundID, runID, mock.Anything). Return(nil). Once() } diff --git a/core/services/fluxmonitorv2/mocks/orm.go b/core/services/fluxmonitorv2/mocks/orm.go index 287c7ebb5fa..e5173db8264 100644 --- a/core/services/fluxmonitorv2/mocks/orm.go +++ b/core/services/fluxmonitorv2/mocks/orm.go @@ -11,7 +11,7 @@ import ( mock "github.com/stretchr/testify/mock" - pg "github.com/smartcontractkit/chainlink/v2/core/services/pg" + sqlutil "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" ) // ORM is an autogenerated mock type for the ORM type @@ -19,9 +19,9 @@ type ORM struct { mock.Mock } -// CountFluxMonitorRoundStats provides a mock function with given fields: -func (_m *ORM) CountFluxMonitorRoundStats() (int, error) { - ret := _m.Called() +// CountFluxMonitorRoundStats provides a mock function with given fields: ctx +func (_m *ORM) CountFluxMonitorRoundStats(ctx context.Context) (int, error) { + ret := _m.Called(ctx) if len(ret) == 0 { panic("no return value specified for CountFluxMonitorRoundStats") @@ -29,17 +29,17 @@ func (_m *ORM) CountFluxMonitorRoundStats() (int, error) { var r0 int var r1 error - if rf, ok := ret.Get(0).(func() (int, error)); ok { - return rf() + if rf, ok := ret.Get(0).(func(context.Context) (int, error)); ok { + return rf(ctx) } - if rf, ok := ret.Get(0).(func() int); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) int); ok { + r0 = rf(ctx) } else { r0 = ret.Get(0).(int) } - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -65,17 +65,17 @@ func (_m *ORM) CreateEthTransaction(ctx context.Context, fromAddress common.Addr return r0 } -// DeleteFluxMonitorRoundsBackThrough provides a mock function with given fields: aggregator, roundID -func (_m *ORM) DeleteFluxMonitorRoundsBackThrough(aggregator common.Address, roundID uint32) error { - ret := _m.Called(aggregator, roundID) +// DeleteFluxMonitorRoundsBackThrough provides a mock function with given fields: ctx, aggregator, roundID +func (_m *ORM) DeleteFluxMonitorRoundsBackThrough(ctx context.Context, aggregator common.Address, roundID uint32) error { + ret := _m.Called(ctx, aggregator, roundID) if len(ret) == 0 { panic("no return value specified for DeleteFluxMonitorRoundsBackThrough") } var r0 error - if rf, ok := ret.Get(0).(func(common.Address, uint32) error); ok { - r0 = rf(aggregator, roundID) + if rf, ok := ret.Get(0).(func(context.Context, common.Address, uint32) error); ok { + r0 = rf(ctx, aggregator, roundID) } else { r0 = ret.Error(0) } @@ -83,9 +83,9 @@ func (_m *ORM) DeleteFluxMonitorRoundsBackThrough(aggregator common.Address, rou return r0 } -// FindOrCreateFluxMonitorRoundStats provides a mock function with given fields: aggregator, roundID, newRoundLogs -func (_m *ORM) FindOrCreateFluxMonitorRoundStats(aggregator common.Address, roundID uint32, newRoundLogs uint) (fluxmonitorv2.FluxMonitorRoundStatsV2, error) { - ret := _m.Called(aggregator, roundID, newRoundLogs) +// FindOrCreateFluxMonitorRoundStats provides a mock function with given fields: ctx, aggregator, roundID, newRoundLogs +func (_m *ORM) FindOrCreateFluxMonitorRoundStats(ctx context.Context, aggregator common.Address, roundID uint32, newRoundLogs uint) (fluxmonitorv2.FluxMonitorRoundStatsV2, error) { + ret := _m.Called(ctx, aggregator, roundID, newRoundLogs) if len(ret) == 0 { panic("no return value specified for FindOrCreateFluxMonitorRoundStats") @@ -93,17 +93,17 @@ func (_m *ORM) FindOrCreateFluxMonitorRoundStats(aggregator common.Address, roun var r0 fluxmonitorv2.FluxMonitorRoundStatsV2 var r1 error - if rf, ok := ret.Get(0).(func(common.Address, uint32, uint) (fluxmonitorv2.FluxMonitorRoundStatsV2, error)); ok { - return rf(aggregator, roundID, newRoundLogs) + if rf, ok := ret.Get(0).(func(context.Context, common.Address, uint32, uint) (fluxmonitorv2.FluxMonitorRoundStatsV2, error)); ok { + return rf(ctx, aggregator, roundID, newRoundLogs) } - if rf, ok := ret.Get(0).(func(common.Address, uint32, uint) fluxmonitorv2.FluxMonitorRoundStatsV2); ok { - r0 = rf(aggregator, roundID, newRoundLogs) + if rf, ok := ret.Get(0).(func(context.Context, common.Address, uint32, uint) fluxmonitorv2.FluxMonitorRoundStatsV2); ok { + r0 = rf(ctx, aggregator, roundID, newRoundLogs) } else { r0 = ret.Get(0).(fluxmonitorv2.FluxMonitorRoundStatsV2) } - if rf, ok := ret.Get(1).(func(common.Address, uint32, uint) error); ok { - r1 = rf(aggregator, roundID, newRoundLogs) + if rf, ok := ret.Get(1).(func(context.Context, common.Address, uint32, uint) error); ok { + r1 = rf(ctx, aggregator, roundID, newRoundLogs) } else { r1 = ret.Error(1) } @@ -111,9 +111,9 @@ func (_m *ORM) FindOrCreateFluxMonitorRoundStats(aggregator common.Address, roun return r0, r1 } -// MostRecentFluxMonitorRoundID provides a mock function with given fields: aggregator -func (_m *ORM) MostRecentFluxMonitorRoundID(aggregator common.Address) (uint32, error) { - ret := _m.Called(aggregator) +// MostRecentFluxMonitorRoundID provides a mock function with given fields: ctx, aggregator +func (_m *ORM) MostRecentFluxMonitorRoundID(ctx context.Context, aggregator common.Address) (uint32, error) { + ret := _m.Called(ctx, aggregator) if len(ret) == 0 { panic("no return value specified for MostRecentFluxMonitorRoundID") @@ -121,17 +121,17 @@ func (_m *ORM) MostRecentFluxMonitorRoundID(aggregator common.Address) (uint32, var r0 uint32 var r1 error - if rf, ok := ret.Get(0).(func(common.Address) (uint32, error)); ok { - return rf(aggregator) + if rf, ok := ret.Get(0).(func(context.Context, common.Address) (uint32, error)); ok { + return rf(ctx, aggregator) } - if rf, ok := ret.Get(0).(func(common.Address) uint32); ok { - r0 = rf(aggregator) + if rf, ok := ret.Get(0).(func(context.Context, common.Address) uint32); ok { + r0 = rf(ctx, aggregator) } else { r0 = ret.Get(0).(uint32) } - if rf, ok := ret.Get(1).(func(common.Address) error); ok { - r1 = rf(aggregator) + if rf, ok := ret.Get(1).(func(context.Context, common.Address) error); ok { + r1 = rf(ctx, aggregator) } else { r1 = ret.Error(1) } @@ -139,24 +139,17 @@ func (_m *ORM) MostRecentFluxMonitorRoundID(aggregator common.Address) (uint32, return r0, r1 } -// UpdateFluxMonitorRoundStats provides a mock function with given fields: aggregator, roundID, runID, newRoundLogsAddition, qopts -func (_m *ORM) UpdateFluxMonitorRoundStats(aggregator common.Address, roundID uint32, runID int64, newRoundLogsAddition uint, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, aggregator, roundID, runID, newRoundLogsAddition) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// UpdateFluxMonitorRoundStats provides a mock function with given fields: ctx, aggregator, roundID, runID, newRoundLogsAddition +func (_m *ORM) UpdateFluxMonitorRoundStats(ctx context.Context, aggregator common.Address, roundID uint32, runID int64, newRoundLogsAddition uint) error { + ret := _m.Called(ctx, aggregator, roundID, runID, newRoundLogsAddition) if len(ret) == 0 { panic("no return value specified for UpdateFluxMonitorRoundStats") } var r0 error - if rf, ok := ret.Get(0).(func(common.Address, uint32, int64, uint, ...pg.QOpt) error); ok { - r0 = rf(aggregator, roundID, runID, newRoundLogsAddition, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, common.Address, uint32, int64, uint) error); ok { + r0 = rf(ctx, aggregator, roundID, runID, newRoundLogsAddition) } else { r0 = ret.Error(0) } @@ -164,6 +157,26 @@ func (_m *ORM) UpdateFluxMonitorRoundStats(aggregator common.Address, roundID ui return r0 } +// WithDataSource provides a mock function with given fields: _a0 +func (_m *ORM) WithDataSource(_a0 sqlutil.DataSource) fluxmonitorv2.ORM { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for WithDataSource") + } + + var r0 fluxmonitorv2.ORM + if rf, ok := ret.Get(0).(func(sqlutil.DataSource) fluxmonitorv2.ORM); ok { + r0 = rf(_a0) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(fluxmonitorv2.ORM) + } + } + + return r0 +} + // NewORM creates a new instance of ORM. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewORM(t interface { diff --git a/core/services/fluxmonitorv2/orm.go b/core/services/fluxmonitorv2/orm.go index 91973387e32..e090b84ed04 100644 --- a/core/services/fluxmonitorv2/orm.go +++ b/core/services/fluxmonitorv2/orm.go @@ -7,12 +7,10 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/pkg/errors" - "github.com/jmoiron/sqlx" - + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/common/txmgr/types" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/txmgr" "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) type transmitter interface { @@ -23,48 +21,49 @@ type transmitter interface { // ORM defines an interface for database commands related to Flux Monitor v2 type ORM interface { - MostRecentFluxMonitorRoundID(aggregator common.Address) (uint32, error) - DeleteFluxMonitorRoundsBackThrough(aggregator common.Address, roundID uint32) error - FindOrCreateFluxMonitorRoundStats(aggregator common.Address, roundID uint32, newRoundLogs uint) (FluxMonitorRoundStatsV2, error) - UpdateFluxMonitorRoundStats(aggregator common.Address, roundID uint32, runID int64, newRoundLogsAddition uint, qopts ...pg.QOpt) error + MostRecentFluxMonitorRoundID(ctx context.Context, aggregator common.Address) (uint32, error) + DeleteFluxMonitorRoundsBackThrough(ctx context.Context, aggregator common.Address, roundID uint32) error + FindOrCreateFluxMonitorRoundStats(ctx context.Context, aggregator common.Address, roundID uint32, newRoundLogs uint) (FluxMonitorRoundStatsV2, error) + UpdateFluxMonitorRoundStats(ctx context.Context, aggregator common.Address, roundID uint32, runID int64, newRoundLogsAddition uint) error CreateEthTransaction(ctx context.Context, fromAddress, toAddress common.Address, payload []byte, gasLimit uint64, idempotencyKey *string) error - CountFluxMonitorRoundStats() (count int, err error) + CountFluxMonitorRoundStats(ctx context.Context) (count int, err error) + + WithDataSource(sqlutil.DataSource) ORM } type orm struct { - q pg.Q + ds sqlutil.DataSource txm transmitter strategy types.TxStrategy checker txmgr.TransmitCheckerSpec logger logger.Logger } +func (o *orm) WithDataSource(ds sqlutil.DataSource) ORM { return o.withDataSource(ds) } + +func (o *orm) withDataSource(ds sqlutil.DataSource) *orm { + return &orm{ds, o.txm, o.strategy, o.checker, o.logger} +} + // NewORM initializes a new ORM -func NewORM(db *sqlx.DB, lggr logger.Logger, cfg pg.QConfig, txm transmitter, strategy types.TxStrategy, checker txmgr.TransmitCheckerSpec) ORM { +func NewORM(ds sqlutil.DataSource, lggr logger.Logger, txm transmitter, strategy types.TxStrategy, checker txmgr.TransmitCheckerSpec) ORM { namedLogger := lggr.Named("FluxMonitorORM") - q := pg.NewQ(db, namedLogger, cfg) - return &orm{ - q, - txm, - strategy, - checker, - namedLogger, - } + return &orm{ds, txm, strategy, checker, namedLogger} } // MostRecentFluxMonitorRoundID finds roundID of the most recent round that the // provided oracle address submitted to -func (o *orm) MostRecentFluxMonitorRoundID(aggregator common.Address) (uint32, error) { +func (o *orm) MostRecentFluxMonitorRoundID(ctx context.Context, aggregator common.Address) (uint32, error) { var stats FluxMonitorRoundStatsV2 - err := o.q.Get(&stats, `SELECT * FROM flux_monitor_round_stats_v2 WHERE aggregator = $1 ORDER BY round_id DESC LIMIT 1`, aggregator) + err := o.ds.GetContext(ctx, &stats, `SELECT * FROM flux_monitor_round_stats_v2 WHERE aggregator = $1 ORDER BY round_id DESC LIMIT 1`, aggregator) return stats.RoundID, errors.Wrap(err, "MostRecentFluxMonitorRoundID failed") } // DeleteFluxMonitorRoundsBackThrough deletes all the RoundStat records for a // given oracle address starting from the most recent round back through the // given round -func (o *orm) DeleteFluxMonitorRoundsBackThrough(aggregator common.Address, roundID uint32) error { - _, err := o.q.Exec(` +func (o *orm) DeleteFluxMonitorRoundsBackThrough(ctx context.Context, aggregator common.Address, roundID uint32) error { + _, err := o.ds.ExecContext(ctx, ` DELETE FROM flux_monitor_round_stats_v2 WHERE aggregator = $1 AND round_id >= $2 @@ -74,14 +73,14 @@ func (o *orm) DeleteFluxMonitorRoundsBackThrough(aggregator common.Address, roun // FindOrCreateFluxMonitorRoundStats find the round stats record for a given // oracle on a given round, or creates it if no record exists -func (o *orm) FindOrCreateFluxMonitorRoundStats(aggregator common.Address, roundID uint32, newRoundLogs uint) (stats FluxMonitorRoundStatsV2, err error) { - err = o.q.Transaction(func(tx pg.Queryer) error { - err = tx.Get(&stats, +func (o *orm) FindOrCreateFluxMonitorRoundStats(ctx context.Context, aggregator common.Address, roundID uint32, newRoundLogs uint) (stats FluxMonitorRoundStatsV2, err error) { + err = sqlutil.Transact(ctx, o.withDataSource, o.ds, nil, func(tx *orm) error { + err = tx.ds.GetContext(ctx, &stats, `INSERT INTO flux_monitor_round_stats_v2 (aggregator, round_id, num_new_round_logs, num_submissions) VALUES ($1, $2, $3, 0) ON CONFLICT (aggregator, round_id) DO NOTHING`, aggregator, roundID, newRoundLogs) if errors.Is(err, sql.ErrNoRows) { - err = tx.Get(&stats, `SELECT * FROM flux_monitor_round_stats_v2 WHERE aggregator=$1 AND round_id=$2`, aggregator, roundID) + err = tx.ds.GetContext(ctx, &stats, `SELECT * FROM flux_monitor_round_stats_v2 WHERE aggregator=$1 AND round_id=$2`, aggregator, roundID) } return err }) @@ -91,9 +90,8 @@ func (o *orm) FindOrCreateFluxMonitorRoundStats(aggregator common.Address, round // UpdateFluxMonitorRoundStats trys to create a RoundStat record for the given oracle // at the given round. If one already exists, it increments the num_submissions column. -func (o *orm) UpdateFluxMonitorRoundStats(aggregator common.Address, roundID uint32, runID int64, newRoundLogsAddition uint, qopts ...pg.QOpt) error { - q := o.q.WithOpts(qopts...) - err := q.ExecQ(` +func (o *orm) UpdateFluxMonitorRoundStats(ctx context.Context, aggregator common.Address, roundID uint32, runID int64, newRoundLogsAddition uint) error { + _, err := o.ds.ExecContext(ctx, ` INSERT INTO flux_monitor_round_stats_v2 ( aggregator, round_id, pipeline_run_id, num_new_round_logs, num_submissions ) VALUES ( @@ -108,8 +106,8 @@ func (o *orm) UpdateFluxMonitorRoundStats(aggregator common.Address, roundID uin } // CountFluxMonitorRoundStats counts the total number of records -func (o *orm) CountFluxMonitorRoundStats() (count int, err error) { - err = o.q.Get(&count, `SELECT count(*) FROM flux_monitor_round_stats_v2`) +func (o *orm) CountFluxMonitorRoundStats(ctx context.Context) (count int, err error) { + err = o.ds.GetContext(ctx, &count, `SELECT count(*) FROM flux_monitor_round_stats_v2`) return count, errors.Wrap(err, "CountFluxMonitorRoundStats failed") } diff --git a/core/services/fluxmonitorv2/orm_test.go b/core/services/fluxmonitorv2/orm_test.go index 9b31525831b..f6904b9fe97 100644 --- a/core/services/fluxmonitorv2/orm_test.go +++ b/core/services/fluxmonitorv2/orm_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/require" "github.com/smartcontractkit/chainlink-common/pkg/utils/jsonserializable" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" commontxmmocks "github.com/smartcontractkit/chainlink/v2/common/txmgr/types/mocks" "github.com/smartcontractkit/chainlink/v2/core/bridges" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/txmgr" @@ -28,61 +29,62 @@ import ( func TestORM_MostRecentFluxMonitorRoundID(t *testing.T) { t.Parallel() + ctx := tests.Context(t) db := pgtest.NewSqlxDB(t) - cfg := pgtest.NewQConfig(true) - orm := newORM(t, db, cfg, nil) + orm := newORM(t, db, nil) address := testutils.NewAddress() // Setup the rounds for round := uint32(0); round < 10; round++ { - _, err := orm.FindOrCreateFluxMonitorRoundStats(address, round, 1) + _, err := orm.FindOrCreateFluxMonitorRoundStats(ctx, address, round, 1) require.NoError(t, err) } - count, err := orm.CountFluxMonitorRoundStats() + count, err := orm.CountFluxMonitorRoundStats(ctx) require.NoError(t, err) require.Equal(t, 10, count) // Ensure round stats are not created again for the same address/roundID - stats, err := orm.FindOrCreateFluxMonitorRoundStats(address, uint32(0), 1) + stats, err := orm.FindOrCreateFluxMonitorRoundStats(ctx, address, uint32(0), 1) require.NoError(t, err) require.Equal(t, uint32(0), stats.RoundID) require.Equal(t, address, stats.Aggregator) require.Equal(t, uint64(1), stats.NumNewRoundLogs) - count, err = orm.CountFluxMonitorRoundStats() + count, err = orm.CountFluxMonitorRoundStats(ctx) require.NoError(t, err) require.Equal(t, 10, count) - roundID, err := orm.MostRecentFluxMonitorRoundID(testutils.NewAddress()) + roundID, err := orm.MostRecentFluxMonitorRoundID(ctx, testutils.NewAddress()) require.Error(t, err) require.Equal(t, uint32(0), roundID) - roundID, err = orm.MostRecentFluxMonitorRoundID(address) + roundID, err = orm.MostRecentFluxMonitorRoundID(ctx, address) require.NoError(t, err) require.Equal(t, uint32(9), roundID) // Deleting rounds against a new address should incur no changes - err = orm.DeleteFluxMonitorRoundsBackThrough(testutils.NewAddress(), 5) + err = orm.DeleteFluxMonitorRoundsBackThrough(ctx, testutils.NewAddress(), 5) require.NoError(t, err) - count, err = orm.CountFluxMonitorRoundStats() + count, err = orm.CountFluxMonitorRoundStats(ctx) require.NoError(t, err) require.Equal(t, 10, count) // Deleting rounds against the address - err = orm.DeleteFluxMonitorRoundsBackThrough(address, 5) + err = orm.DeleteFluxMonitorRoundsBackThrough(ctx, address, 5) require.NoError(t, err) - count, err = orm.CountFluxMonitorRoundStats() + count, err = orm.CountFluxMonitorRoundStats(ctx) require.NoError(t, err) require.Equal(t, 5, count) } func TestORM_UpdateFluxMonitorRoundStats(t *testing.T) { t.Parallel() + ctx := tests.Context(t) cfg := configtest.NewGeneralConfig(t, nil) db := pgtest.NewSqlxDB(t) @@ -92,13 +94,13 @@ func TestORM_UpdateFluxMonitorRoundStats(t *testing.T) { // Instantiate a real pipeline ORM because we need to create a pipeline run // for the foreign key constraint of the stats record - pipelineORM := pipeline.NewORM(db, lggr, cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, lggr, cfg.JobPipeline().MaxSuccessfulRuns()) bridgeORM := bridges.NewORM(db) // Instantiate a real job ORM because we need to create a job to satisfy // a check in pipeline.CreateRun jobORM := job.NewORM(db, pipelineORM, bridgeORM, keyStore, lggr, cfg.Database()) - orm := newORM(t, db, cfg.Database(), nil) + orm := newORM(t, db, nil) address := testutils.NewAddress() var roundID uint32 = 1 @@ -129,13 +131,13 @@ func TestORM_UpdateFluxMonitorRoundStats(t *testing.T) { }, }, } - err := pipelineORM.InsertFinishedRun(run, true) + err := pipelineORM.InsertFinishedRun(ctx, run, true) require.NoError(t, err) - err = orm.UpdateFluxMonitorRoundStats(address, roundID, run.ID, 0) + err = orm.UpdateFluxMonitorRoundStats(ctx, address, roundID, run.ID, 0) require.NoError(t, err) - stats, err := orm.FindOrCreateFluxMonitorRoundStats(address, roundID, 0) + stats, err := orm.FindOrCreateFluxMonitorRoundStats(ctx, address, roundID, 0) require.NoError(t, err) require.Equal(t, expectedCount, stats.NumSubmissions) require.True(t, stats.PipelineRunID.Valid) @@ -177,7 +179,7 @@ func TestORM_CreateEthTransaction(t *testing.T) { var ( txm = txmmocks.NewMockEvmTxManager(t) - orm = fluxmonitorv2.NewORM(db, logger.TestLogger(t), cfg, txm, strategy, txmgr.TransmitCheckerSpec{}) + orm = fluxmonitorv2.NewORM(db, logger.TestLogger(t), txm, strategy, txmgr.TransmitCheckerSpec{}) _, from = cltest.MustInsertRandomKey(t, ethKeyStore) to = testutils.NewAddress() diff --git a/core/services/gateway/delegate.go b/core/services/gateway/delegate.go index ba34f2894de..8cddc027803 100644 --- a/core/services/gateway/delegate.go +++ b/core/services/gateway/delegate.go @@ -41,10 +41,10 @@ func (d *Delegate) JobType() job.Type { return job.Gateway } -func (d *Delegate) BeforeJobCreated(spec job.Job) {} -func (d *Delegate) AfterJobCreated(spec job.Job) {} -func (d *Delegate) BeforeJobDeleted(spec job.Job) {} -func (d *Delegate) OnDeleteJob(ctx context.Context, spec job.Job, q pg.Queryer) error { return nil } +func (d *Delegate) BeforeJobCreated(spec job.Job) {} +func (d *Delegate) AfterJobCreated(spec job.Job) {} +func (d *Delegate) BeforeJobDeleted(spec job.Job) {} +func (d *Delegate) OnDeleteJob(context.Context, job.Job) error { return nil } // ServicesForSpec returns the scheduler to be used for running observer jobs func (d *Delegate) ServicesForSpec(ctx context.Context, spec job.Job) (services []job.ServiceCtx, err error) { diff --git a/core/services/job/job_orm_test.go b/core/services/job/job_orm_test.go index a6e3622df1b..c60f096c358 100644 --- a/core/services/job/job_orm_test.go +++ b/core/services/job/job_orm_test.go @@ -82,7 +82,7 @@ func TestORM(t *testing.T) { require.NoError(t, keyStore.OCR().Add(cltest.DefaultOCRKey)) require.NoError(t, keyStore.P2P().Add(cltest.DefaultP2PKey)) - pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) orm := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) @@ -346,7 +346,7 @@ func TestORM_DeleteJob_DeletesAssociatedRecords(t *testing.T) { require.NoError(t, keyStore.P2P().Add(cltest.DefaultP2PKey)) lggr := logger.TestLogger(t) - pipelineORM := pipeline.NewORM(db, lggr, config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) jobORM := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) korm := keeper.NewORM(db, logger.TestLogger(t)) @@ -444,7 +444,7 @@ func TestORM_CreateJob_VRFV2(t *testing.T) { require.NoError(t, keyStore.OCR().Add(cltest.DefaultOCRKey)) lggr := logger.TestLogger(t) - pipelineORM := pipeline.NewORM(db, lggr, config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) jobORM := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) @@ -528,7 +528,7 @@ func TestORM_CreateJob_VRFV2Plus(t *testing.T) { require.NoError(t, keyStore.OCR().Add(cltest.DefaultOCRKey)) lggr := logger.TestLogger(t) - pipelineORM := pipeline.NewORM(db, lggr, config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) jobORM := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) @@ -615,7 +615,7 @@ func TestORM_CreateJob_OCRBootstrap(t *testing.T) { require.NoError(t, keyStore.OCR().Add(cltest.DefaultOCRKey)) lggr := logger.TestLogger(t) - pipelineORM := pipeline.NewORM(db, lggr, config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) jobORM := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) @@ -641,7 +641,7 @@ func TestORM_CreateJob_EVMChainID_Validation(t *testing.T) { keyStore := cltest.NewKeyStore(t, db, config.Database()) lggr := logger.TestLogger(t) - pipelineORM := pipeline.NewORM(db, lggr, config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) jobORM := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) @@ -736,7 +736,7 @@ func TestORM_CreateJob_OCR_DuplicatedContractAddress(t *testing.T) { require.NoError(t, keyStore.OCR().Add(cltest.DefaultOCRKey)) lggr := logger.TestLogger(t) - pipelineORM := pipeline.NewORM(db, lggr, config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) jobORM := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) @@ -805,7 +805,7 @@ func TestORM_CreateJob_OCR2_DuplicatedContractAddress(t *testing.T) { require.NoError(t, keyStore.OCR2().Add(cltest.DefaultOCR2Key)) lggr := logger.TestLogger(t) - pipelineORM := pipeline.NewORM(db, lggr, config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) jobORM := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) @@ -866,7 +866,7 @@ func TestORM_CreateJob_OCR2_Sending_Keys_Transmitter_Keys_Validations(t *testing require.NoError(t, keyStore.OCR2().Add(cltest.DefaultOCR2Key)) lggr := logger.TestLogger(t) - pipelineORM := pipeline.NewORM(db, lggr, config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) jobORM := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) @@ -986,7 +986,7 @@ func Test_FindJobs(t *testing.T) { require.NoError(t, keyStore.OCR().Add(cltest.DefaultOCRKey)) require.NoError(t, keyStore.P2P().Add(cltest.DefaultP2PKey)) - pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) orm := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) @@ -1067,7 +1067,7 @@ func Test_FindJob(t *testing.T) { require.NoError(t, keyStore.P2P().Add(cltest.DefaultP2PKey)) require.NoError(t, keyStore.CSA().Add(cltest.DefaultCSAKey)) - pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) orm := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) @@ -1250,7 +1250,7 @@ func Test_FindJobsByPipelineSpecIDs(t *testing.T) { keyStore := cltest.NewKeyStore(t, db, config.Database()) require.NoError(t, keyStore.OCR().Add(cltest.DefaultOCRKey)) - pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) orm := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) @@ -1298,7 +1298,7 @@ func Test_FindPipelineRuns(t *testing.T) { require.NoError(t, keyStore.OCR().Add(cltest.DefaultOCRKey)) require.NoError(t, keyStore.P2P().Add(cltest.DefaultP2PKey)) - pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) relayExtenders := evmtest.NewChainRelayExtenders(t, evmtest.TestChainOpts{DB: db, GeneralConfig: config, KeyStore: keyStore.Eth()}) legacyChains := evmrelay.NewLegacyChainsFromRelayerExtenders(relayExtenders) @@ -1359,7 +1359,7 @@ func Test_PipelineRunsByJobID(t *testing.T) { require.NoError(t, keyStore.OCR().Add(cltest.DefaultOCRKey)) require.NoError(t, keyStore.P2P().Add(cltest.DefaultP2PKey)) - pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) relayExtenders := evmtest.NewChainRelayExtenders(t, evmtest.TestChainOpts{DB: db, GeneralConfig: config, KeyStore: keyStore.Eth()}) legacyChains := evmrelay.NewLegacyChainsFromRelayerExtenders(relayExtenders) @@ -1419,7 +1419,7 @@ func Test_FindPipelineRunIDsByJobID(t *testing.T) { require.NoError(t, keyStore.OCR().Add(cltest.DefaultOCRKey)) require.NoError(t, keyStore.P2P().Add(cltest.DefaultP2PKey)) - pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) relayExtenders := evmtest.NewChainRelayExtenders(t, evmtest.TestChainOpts{DB: db, GeneralConfig: config, KeyStore: keyStore.Eth()}) legacyChains := evmrelay.NewLegacyChainsFromRelayerExtenders(relayExtenders) @@ -1527,7 +1527,7 @@ func Test_FindPipelineRunsByIDs(t *testing.T) { require.NoError(t, keyStore.OCR().Add(cltest.DefaultOCRKey)) require.NoError(t, keyStore.P2P().Add(cltest.DefaultP2PKey)) - pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) relayExtenders := evmtest.NewChainRelayExtenders(t, evmtest.TestChainOpts{DB: db, GeneralConfig: config, KeyStore: keyStore.Eth()}) legacyChains := evmrelay.NewLegacyChainsFromRelayerExtenders(relayExtenders) @@ -1585,7 +1585,7 @@ func Test_FindPipelineRunByID(t *testing.T) { err := keyStore.OCR().Add(cltest.DefaultOCRKey) require.NoError(t, err) - pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) orm := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) @@ -1628,7 +1628,7 @@ func Test_FindJobWithoutSpecErrors(t *testing.T) { err := keyStore.OCR().Add(cltest.DefaultOCRKey) require.NoError(t, err) - pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) orm := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) @@ -1665,7 +1665,7 @@ func Test_FindSpecErrorsByJobIDs(t *testing.T) { err := keyStore.OCR().Add(cltest.DefaultOCRKey) require.NoError(t, err) - pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) orm := NewTestORM(t, db, pipelineORM, bridgesORM, keyStore, config.Database()) @@ -1699,7 +1699,7 @@ func Test_CountPipelineRunsByJobID(t *testing.T) { require.NoError(t, keyStore.OCR().Add(cltest.DefaultOCRKey)) require.NoError(t, keyStore.P2P().Add(cltest.DefaultP2PKey)) - pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) relayExtenders := evmtest.NewChainRelayExtenders(t, evmtest.TestChainOpts{DB: db, GeneralConfig: config, KeyStore: keyStore.Eth()}) legacyChains := evmrelay.NewLegacyChainsFromRelayerExtenders(relayExtenders) @@ -1740,6 +1740,7 @@ func Test_CountPipelineRunsByJobID(t *testing.T) { func mustInsertPipelineRun(t *testing.T, orm pipeline.ORM, j job.Job) pipeline.Run { t.Helper() + ctx := testutils.Context(t) run := pipeline.Run{ PipelineSpecID: j.PipelineSpecID, @@ -1750,7 +1751,7 @@ func mustInsertPipelineRun(t *testing.T, orm pipeline.ORM, j job.Job) pipeline.R CreatedAt: time.Now(), FinishedAt: null.Time{}, } - err := orm.CreateRun(&run) + err := orm.CreateRun(ctx, &run) require.NoError(t, err) return run } diff --git a/core/services/job/job_pipeline_orm_integration_test.go b/core/services/job/job_pipeline_orm_integration_test.go index 698e60eca7b..696005c270e 100644 --- a/core/services/job/job_pipeline_orm_integration_test.go +++ b/core/services/job/job_pipeline_orm_integration_test.go @@ -126,14 +126,15 @@ func TestPipelineORM_Integration(t *testing.T) { _, bridge2 := cltest.MustCreateBridge(t, db, cltest.BridgeOpts{}) t.Run("creates task DAGs", func(t *testing.T) { + ctx := testutils.Context(t) clearJobsDb(t, db) - orm := pipeline.NewORM(db, logger.TestLogger(t), config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + orm := pipeline.NewORM(db, logger.TestLogger(t), config.JobPipeline().MaxSuccessfulRuns()) p, err := pipeline.Parse(DotStr) require.NoError(t, err) - specID, err = orm.CreateSpec(*p, models.Interval(0)) + specID, err = orm.CreateSpec(ctx, nil, *p, models.Interval(0)) require.NoError(t, err) var pipelineSpecs []pipeline.Spec @@ -152,7 +153,7 @@ func TestPipelineORM_Integration(t *testing.T) { lggr := logger.TestLogger(t) cfg := configtest.NewTestGeneralConfig(t) clearJobsDb(t, db) - orm := pipeline.NewORM(db, logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) + orm := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) btORM := bridges.NewORM(db) relayExtenders := evmtest.NewChainRelayExtenders(t, evmtest.TestChainOpts{Client: evmtest.NewEthClientMockWithDefaultChain(t), DB: db, GeneralConfig: config, KeyStore: ethKeyStore}) legacyChains := evmrelay.NewLegacyChainsFromRelayerExtenders(relayExtenders) diff --git a/core/services/job/kv_orm_test.go b/core/services/job/kv_orm_test.go index 6a3269e9992..3ba03b8bc3c 100644 --- a/core/services/job/kv_orm_test.go +++ b/core/services/job/kv_orm_test.go @@ -29,7 +29,7 @@ func TestJobKVStore(t *testing.T) { lggr := logger.TestLogger(t) - pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) jobID := int32(1337) diff --git a/core/services/job/orm.go b/core/services/job/orm.go index 4ac1e7c6047..9d2a6545163 100644 --- a/core/services/job/orm.go +++ b/core/services/job/orm.go @@ -456,7 +456,7 @@ func (o *orm) CreateJob(jb *Job, qopts ...pg.QOpt) error { o.lggr.Panicf("Unsupported jb.Type: %v", jb.Type) } - pipelineSpecID, err := o.pipelineORM.CreateSpec(p, jb.MaxTaskDuration, pg.WithQueryer(tx)) + pipelineSpecID, err := o.pipelineORM.CreateSpec(ctx, tx, p, jb.MaxTaskDuration) if err != nil { return errors.Wrap(err, "failed to create pipeline spec") } diff --git a/core/services/job/runner_integration_test.go b/core/services/job/runner_integration_test.go index ed2950ac382..6149bb71cf6 100644 --- a/core/services/job/runner_integration_test.go +++ b/core/services/job/runner_integration_test.go @@ -80,7 +80,7 @@ func TestRunner(t *testing.T) { ethClient.On("CallContract", mock.Anything, mock.Anything, mock.Anything).Maybe().Return(nil, nil) ctx := testutils.Context(t) - pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, logger.TestLogger(t), config.JobPipeline().MaxSuccessfulRuns()) require.NoError(t, pipelineORM.Start(ctx)) t.Cleanup(func() { assert.NoError(t, pipelineORM.Close()) }) btORM := bridges.NewORM(db) @@ -888,7 +888,7 @@ func TestRunner_Success_Callback_AsyncJob(t *testing.T) { _ = cltest.CreateJobRunViaExternalInitiatorV2(t, app, jobUUID, *eia, cltest.MustJSONMarshal(t, eiRequest)) - pipelineORM := pipeline.NewORM(app.GetSqlxDB(), logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(app.GetSqlxDB(), logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(app.GetSqlxDB()) jobORM := NewTestORM(t, app.GetSqlxDB(), pipelineORM, bridgesORM, app.KeyStore, cfg.Database()) @@ -1065,7 +1065,7 @@ func TestRunner_Error_Callback_AsyncJob(t *testing.T) { t.Run("simulate request from EI -> Core node with erroring callback", func(t *testing.T) { _ = cltest.CreateJobRunViaExternalInitiatorV2(t, app, jobUUID, *eia, cltest.MustJSONMarshal(t, eiRequest)) - pipelineORM := pipeline.NewORM(app.GetSqlxDB(), logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(app.GetSqlxDB(), logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(app.GetSqlxDB()) jobORM := NewTestORM(t, app.GetSqlxDB(), pipelineORM, bridgesORM, app.KeyStore, cfg.Database()) diff --git a/core/services/job/spawner.go b/core/services/job/spawner.go index 3d30a3190b3..8024424226c 100644 --- a/core/services/job/spawner.go +++ b/core/services/job/spawner.go @@ -78,7 +78,7 @@ type ( // non-db side effects. This is required in order to guarantee mutual atomicity between // all tasks intended to happen during job deletion. For the same reason, the job will // not show up in the db within OnDeleteJob(), even though it is still actively running. - OnDeleteJob(ctx context.Context, jb Job, q pg.Queryer) error + OnDeleteJob(ctx context.Context, jb Job) error } activeJob struct { @@ -340,7 +340,7 @@ func (js *spawner) DeleteJob(jobID int32, qopts ...pg.QOpt) error { // we know the DELETE will succeed. The DELETE will be finalized only if all db transactions in OnDeleteJob() // succeed. If either of those fails, the job will not be stopped and everything will be rolled back. lggr.Debugw("Callback: OnDeleteJob") - err = aj.delegate.OnDeleteJob(ctx, aj.spec, tx) + err = aj.delegate.OnDeleteJob(ctx, aj.spec) if err != nil { return err } @@ -395,7 +395,9 @@ func (n *NullDelegate) ServicesForSpec(ctx context.Context, spec Job) (s []Servi return } -func (n *NullDelegate) BeforeJobCreated(spec Job) {} -func (n *NullDelegate) AfterJobCreated(spec Job) {} -func (n *NullDelegate) BeforeJobDeleted(spec Job) {} -func (n *NullDelegate) OnDeleteJob(ctx context.Context, spec Job, q pg.Queryer) error { return nil } +func (n *NullDelegate) BeforeJobCreated(spec Job) {} +func (n *NullDelegate) AfterJobCreated(spec Job) {} +func (n *NullDelegate) BeforeJobDeleted(spec Job) {} +func (n *NullDelegate) OnDeleteJob(context.Context, Job) error { + return nil +} diff --git a/core/services/job/spawner_test.go b/core/services/job/spawner_test.go index d2e7a80d5d4..802763cfaab 100644 --- a/core/services/job/spawner_test.go +++ b/core/services/job/spawner_test.go @@ -100,7 +100,7 @@ func TestSpawner_CreateJobDeleteJob(t *testing.T) { legacyChains := evmrelay.NewLegacyChainsFromRelayerExtenders(relayExtenders) t.Run("should respect its dependents", func(t *testing.T) { lggr := logger.TestLogger(t) - orm := NewTestORM(t, db, pipeline.NewORM(db, lggr, config.Database(), config.JobPipeline().MaxSuccessfulRuns()), bridges.NewORM(db), keyStore, config.Database()) + orm := NewTestORM(t, db, pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()), bridges.NewORM(db), keyStore, config.Database()) a := utils.NewDependentAwaiter() a.AddDependents(1) spawner := job.NewSpawner(orm, config.Database(), noopChecker{}, map[job.Type]job.Delegate{}, db, lggr, []utils.DependentAwaiter{a}) @@ -123,7 +123,7 @@ func TestSpawner_CreateJobDeleteJob(t *testing.T) { jobB := makeOCRJobSpec(t, address, bridge.Name.String(), bridge2.Name.String()) lggr := logger.TestLogger(t) - orm := NewTestORM(t, db, pipeline.NewORM(db, lggr, config.Database(), config.JobPipeline().MaxSuccessfulRuns()), bridges.NewORM(db), keyStore, config.Database()) + orm := NewTestORM(t, db, pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()), bridges.NewORM(db), keyStore, config.Database()) eventuallyA := cltest.NewAwaiter() serviceA1 := mocks.NewServiceCtx(t) @@ -188,7 +188,7 @@ func TestSpawner_CreateJobDeleteJob(t *testing.T) { serviceA2.On("Start", mock.Anything).Return(nil).Once().Run(func(mock.Arguments) { eventually.ItHappened() }) lggr := logger.TestLogger(t) - orm := NewTestORM(t, db, pipeline.NewORM(db, lggr, config.Database(), config.JobPipeline().MaxSuccessfulRuns()), bridges.NewORM(db), keyStore, config.Database()) + orm := NewTestORM(t, db, pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()), bridges.NewORM(db), keyStore, config.Database()) mailMon := servicetest.Run(t, mailboxtest.NewMonitor(t)) d := ocr.NewDelegate(nil, orm, nil, nil, nil, monitoringEndpoint, legacyChains, logger.TestLogger(t), config.Database(), mailMon) delegateA := &delegate{jobA.Type, []job.ServiceCtx{serviceA1, serviceA2}, 0, nil, d} @@ -222,7 +222,7 @@ func TestSpawner_CreateJobDeleteJob(t *testing.T) { serviceA2.On("Start", mock.Anything).Return(nil).Once().Run(func(mock.Arguments) { eventuallyStart.ItHappened() }) lggr := logger.TestLogger(t) - orm := NewTestORM(t, db, pipeline.NewORM(db, lggr, config.Database(), config.JobPipeline().MaxSuccessfulRuns()), bridges.NewORM(db), keyStore, config.Database()) + orm := NewTestORM(t, db, pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()), bridges.NewORM(db), keyStore, config.Database()) mailMon := servicetest.Run(t, mailboxtest.NewMonitor(t)) d := ocr.NewDelegate(nil, orm, nil, nil, nil, monitoringEndpoint, legacyChains, logger.TestLogger(t), config.Database(), mailMon) delegateA := &delegate{jobA.Type, []job.ServiceCtx{serviceA1, serviceA2}, 0, nil, d} @@ -300,7 +300,7 @@ func TestSpawner_CreateJobDeleteJob(t *testing.T) { jobOCR2VRF := makeOCR2VRFJobSpec(t, keyStore, config, address, chain.ID(), 2) - orm := NewTestORM(t, db, pipeline.NewORM(db, lggr, config.Database(), config.JobPipeline().MaxSuccessfulRuns()), bridges.NewORM(db), keyStore, config.Database()) + orm := NewTestORM(t, db, pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()), bridges.NewORM(db), keyStore, config.Database()) mailMon := servicetest.Run(t, mailboxtest.NewMonitor(t)) processConfig := plugins.NewRegistrarConfig(loop.GRPCOpts{}, func(name string) (*plugins.RegisteredLoop, error) { return nil, nil }, func(loopId string) {}) diff --git a/core/services/keeper/delegate.go b/core/services/keeper/delegate.go index 679ccf3053d..9652434759b 100644 --- a/core/services/keeper/delegate.go +++ b/core/services/keeper/delegate.go @@ -11,7 +11,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/chains/legacyevm" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/job" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" ) @@ -51,10 +50,10 @@ func (d *Delegate) JobType() job.Type { return job.Keeper } -func (d *Delegate) BeforeJobCreated(spec job.Job) {} -func (d *Delegate) AfterJobCreated(spec job.Job) {} -func (d *Delegate) BeforeJobDeleted(spec job.Job) {} -func (d *Delegate) OnDeleteJob(ctx context.Context, spec job.Job, q pg.Queryer) error { return nil } +func (d *Delegate) BeforeJobCreated(spec job.Job) {} +func (d *Delegate) AfterJobCreated(spec job.Job) {} +func (d *Delegate) BeforeJobDeleted(spec job.Job) {} +func (d *Delegate) OnDeleteJob(context.Context, job.Job) error { return nil } // ServicesForSpec satisfies the job.Delegate interface. func (d *Delegate) ServicesForSpec(ctx context.Context, spec job.Job) (services []job.ServiceCtx, err error) { diff --git a/core/services/keeper/integration_test.go b/core/services/keeper/integration_test.go index 49073c8de56..08699d3d835 100644 --- a/core/services/keeper/integration_test.go +++ b/core/services/keeper/integration_test.go @@ -175,6 +175,7 @@ func TestKeeperEthIntegration(t *testing.T) { test := tt t.Run(test.name, func(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) g := gomega.NewWithT(t) // setup node key @@ -249,12 +250,12 @@ func TestKeeperEthIntegration(t *testing.T) { korm := keeper.NewORM(db, logger.TestLogger(t)) app := cltest.NewApplicationWithConfigV2AndKeyOnSimulatedBlockchain(t, config, backend.Backend(), nodeKey) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // create job regAddrEIP55 := evmtypes.EIP55AddressFromAddress(regAddr) job := cltest.MustInsertKeeperJob(t, db, korm, nodeAddressEIP55, regAddrEIP55) - err = app.JobSpawner().StartService(testutils.Context(t), job) + err = app.JobSpawner().StartService(ctx, job) require.NoError(t, err) // keeper job is triggered and payload is received @@ -311,7 +312,7 @@ func TestKeeperEthIntegration(t *testing.T) { cltest.AssertRecordEventually(t, app.GetSqlxDB(), ®istry, fmt.Sprintf("SELECT * FROM keeper_registries WHERE id = %d", registry.ID), func() bool { return registry.KeeperIndex == -1 }) - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) // Since we set grace period to 0, we can have more than 1 pipeline run per perform // This happens in case we start a pipeline run before previous perform tx is committed to chain @@ -481,6 +482,7 @@ func TestKeeperForwarderEthIntegration(t *testing.T) { func TestMaxPerformDataSize(t *testing.T) { t.Parallel() t.Run("max_perform_data_size_test", func(t *testing.T) { + ctx := testutils.Context(t) maxPerformDataSize := 1000 // Will be set as config override g := gomega.NewWithT(t) @@ -552,12 +554,12 @@ func TestMaxPerformDataSize(t *testing.T) { korm := keeper.NewORM(db, logger.TestLogger(t)) app := cltest.NewApplicationWithConfigV2AndKeyOnSimulatedBlockchain(t, config, backend.Backend(), nodeKey) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // create job regAddrEIP55 := evmtypes.EIP55AddressFromAddress(regAddr) job := cltest.MustInsertKeeperJob(t, db, korm, nodeAddressEIP55, regAddrEIP55) - err = app.JobSpawner().StartService(testutils.Context(t), job) + err = app.JobSpawner().StartService(ctx, job) require.NoError(t, err) // keeper job is triggered diff --git a/core/services/keeper/registry1_1_synchronizer_test.go b/core/services/keeper/registry1_1_synchronizer_test.go index 24a6a7288a7..61482208e5c 100644 --- a/core/services/keeper/registry1_1_synchronizer_test.go +++ b/core/services/keeper/registry1_1_synchronizer_test.go @@ -201,6 +201,7 @@ func Test_RegistrySynchronizer1_1_FullSync(t *testing.T) { } func Test_RegistrySynchronizer1_1_ConfigSetLog(t *testing.T) { + ctx := testutils.Context(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_1) contractAddress := job.KeeperSpec.ContractAddress.Address() @@ -235,11 +236,11 @@ func Test_RegistrySynchronizer1_1_ConfigSetLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) cltest.AssertRecordEventually(t, db, ®istry, fmt.Sprintf(`SELECT * FROM keeper_registries WHERE id = %d`, registry.ID), func() bool { return registry.BlockCountPerTurn == 40 @@ -248,6 +249,7 @@ func Test_RegistrySynchronizer1_1_ConfigSetLog(t *testing.T) { } func Test_RegistrySynchronizer1_1_KeepersUpdatedLog(t *testing.T) { + ctx := testutils.Context(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_1) contractAddress := job.KeeperSpec.ContractAddress.Address() @@ -281,11 +283,11 @@ func Test_RegistrySynchronizer1_1_KeepersUpdatedLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) cltest.AssertRecordEventually(t, db, ®istry, fmt.Sprintf(`SELECT * FROM keeper_registries WHERE id = %d`, registry.ID), func() bool { return registry.NumKeepers == 2 @@ -293,6 +295,7 @@ func Test_RegistrySynchronizer1_1_KeepersUpdatedLog(t *testing.T) { cltest.AssertCount(t, db, "keeper_registries", 1) } func Test_RegistrySynchronizer1_1_UpkeepCanceledLog(t *testing.T) { + ctx := testutils.Context(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_1) contractAddress := job.KeeperSpec.ContractAddress.Address() @@ -320,16 +323,17 @@ func Test_RegistrySynchronizer1_1_UpkeepCanceledLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) cltest.WaitForCount(t, db, "upkeep_registrations", 2) } func Test_RegistrySynchronizer1_1_UpkeepRegisteredLog(t *testing.T) { + ctx := testutils.Context(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_1) contractAddress := job.KeeperSpec.ContractAddress.Address() @@ -360,16 +364,17 @@ func Test_RegistrySynchronizer1_1_UpkeepRegisteredLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) cltest.WaitForCount(t, db, "upkeep_registrations", 2) } func Test_RegistrySynchronizer1_1_UpkeepPerformedLog(t *testing.T) { + ctx := testutils.Context(t) g := gomega.NewWithT(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_1) @@ -401,11 +406,11 @@ func Test_RegistrySynchronizer1_1_UpkeepPerformedLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) g.Eventually(func() int64 { var upkeep keeper.UpkeepRegistration diff --git a/core/services/keeper/registry1_2_synchronizer_test.go b/core/services/keeper/registry1_2_synchronizer_test.go index 23e6c0355ec..a62e27b8759 100644 --- a/core/services/keeper/registry1_2_synchronizer_test.go +++ b/core/services/keeper/registry1_2_synchronizer_test.go @@ -220,6 +220,7 @@ func Test_RegistrySynchronizer1_2_FullSync(t *testing.T) { } func Test_RegistrySynchronizer1_2_ConfigSetLog(t *testing.T) { + ctx := testutils.Context(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_2) contractAddress := job.KeeperSpec.ContractAddress.Address() @@ -258,11 +259,11 @@ func Test_RegistrySynchronizer1_2_ConfigSetLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) cltest.AssertRecordEventually(t, db, ®istry, fmt.Sprintf(`SELECT * FROM keeper_registries WHERE id = %d`, registry.ID), func() bool { return registry.BlockCountPerTurn == 40 @@ -271,6 +272,7 @@ func Test_RegistrySynchronizer1_2_ConfigSetLog(t *testing.T) { } func Test_RegistrySynchronizer1_2_KeepersUpdatedLog(t *testing.T) { + ctx := testutils.Context(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_2) contractAddress := job.KeeperSpec.ContractAddress.Address() @@ -308,11 +310,11 @@ func Test_RegistrySynchronizer1_2_KeepersUpdatedLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) cltest.AssertRecordEventually(t, db, ®istry, fmt.Sprintf(`SELECT * FROM keeper_registries WHERE id = %d`, registry.ID), func() bool { return registry.NumKeepers == 2 @@ -321,6 +323,7 @@ func Test_RegistrySynchronizer1_2_KeepersUpdatedLog(t *testing.T) { } func Test_RegistrySynchronizer1_2_UpkeepCanceledLog(t *testing.T) { + ctx := testutils.Context(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_2) contractAddress := job.KeeperSpec.ContractAddress.Address() @@ -349,16 +352,17 @@ func Test_RegistrySynchronizer1_2_UpkeepCanceledLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) cltest.WaitForCount(t, db, "upkeep_registrations", 2) } func Test_RegistrySynchronizer1_2_UpkeepRegisteredLog(t *testing.T) { + ctx := testutils.Context(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_2) contractAddress := job.KeeperSpec.ContractAddress.Address() @@ -390,16 +394,17 @@ func Test_RegistrySynchronizer1_2_UpkeepRegisteredLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) cltest.WaitForCount(t, db, "upkeep_registrations", 2) } func Test_RegistrySynchronizer1_2_UpkeepPerformedLog(t *testing.T) { + ctx := testutils.Context(t) g := gomega.NewWithT(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_2) @@ -432,11 +437,11 @@ func Test_RegistrySynchronizer1_2_UpkeepPerformedLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) g.Eventually(func() int64 { var upkeep keeper.UpkeepRegistration @@ -454,6 +459,7 @@ func Test_RegistrySynchronizer1_2_UpkeepPerformedLog(t *testing.T) { } func Test_RegistrySynchronizer1_2_UpkeepGasLimitSetLog(t *testing.T) { + ctx := testutils.Context(t) g := gomega.NewWithT(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_2) @@ -496,16 +502,17 @@ func Test_RegistrySynchronizer1_2_UpkeepGasLimitSetLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) g.Eventually(getExecuteGas, testutils.WaitTimeout(t), cltest.DBPollingInterval).Should(gomega.Equal(uint32(4_000_000))) } func Test_RegistrySynchronizer1_2_UpkeepReceivedLog(t *testing.T) { + ctx := testutils.Context(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_2) contractAddress := job.KeeperSpec.ContractAddress.Address() @@ -537,16 +544,17 @@ func Test_RegistrySynchronizer1_2_UpkeepReceivedLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) cltest.WaitForCount(t, db, "upkeep_registrations", 2) } func Test_RegistrySynchronizer1_2_UpkeepMigratedLog(t *testing.T) { + ctx := testutils.Context(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_2) contractAddress := job.KeeperSpec.ContractAddress.Address() @@ -575,11 +583,11 @@ func Test_RegistrySynchronizer1_2_UpkeepMigratedLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) cltest.WaitForCount(t, db, "upkeep_registrations", 2) } diff --git a/core/services/keeper/registry1_3_synchronizer_test.go b/core/services/keeper/registry1_3_synchronizer_test.go index 2b5900ac189..7ebbbc25469 100644 --- a/core/services/keeper/registry1_3_synchronizer_test.go +++ b/core/services/keeper/registry1_3_synchronizer_test.go @@ -225,6 +225,7 @@ func Test_RegistrySynchronizer1_3_FullSync(t *testing.T) { } func Test_RegistrySynchronizer1_3_ConfigSetLog(t *testing.T) { + ctx := testutils.Context(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_3) contractAddress := job.KeeperSpec.ContractAddress.Address() @@ -263,11 +264,11 @@ func Test_RegistrySynchronizer1_3_ConfigSetLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) cltest.AssertRecordEventually(t, db, ®istry, fmt.Sprintf(`SELECT * FROM keeper_registries WHERE id = %d`, registry.ID), func() bool { return registry.BlockCountPerTurn == 40 @@ -276,6 +277,7 @@ func Test_RegistrySynchronizer1_3_ConfigSetLog(t *testing.T) { } func Test_RegistrySynchronizer1_3_KeepersUpdatedLog(t *testing.T) { + ctx := testutils.Context(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_3) contractAddress := job.KeeperSpec.ContractAddress.Address() @@ -313,11 +315,11 @@ func Test_RegistrySynchronizer1_3_KeepersUpdatedLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) cltest.AssertRecordEventually(t, db, ®istry, fmt.Sprintf(`SELECT * FROM keeper_registries WHERE id = %d`, registry.ID), func() bool { return registry.NumKeepers == 2 @@ -326,6 +328,7 @@ func Test_RegistrySynchronizer1_3_KeepersUpdatedLog(t *testing.T) { } func Test_RegistrySynchronizer1_3_UpkeepCanceledLog(t *testing.T) { + ctx := testutils.Context(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_3) contractAddress := job.KeeperSpec.ContractAddress.Address() @@ -354,16 +357,17 @@ func Test_RegistrySynchronizer1_3_UpkeepCanceledLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) cltest.WaitForCount(t, db, "upkeep_registrations", 2) } func Test_RegistrySynchronizer1_3_UpkeepRegisteredLog(t *testing.T) { + ctx := testutils.Context(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_3) contractAddress := job.KeeperSpec.ContractAddress.Address() @@ -395,16 +399,17 @@ func Test_RegistrySynchronizer1_3_UpkeepRegisteredLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) cltest.WaitForCount(t, db, "upkeep_registrations", 2) } func Test_RegistrySynchronizer1_3_UpkeepPerformedLog(t *testing.T) { + ctx := testutils.Context(t) g := gomega.NewWithT(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_3) @@ -437,11 +442,11 @@ func Test_RegistrySynchronizer1_3_UpkeepPerformedLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) g.Eventually(func() int64 { var upkeep keeper.UpkeepRegistration @@ -459,6 +464,7 @@ func Test_RegistrySynchronizer1_3_UpkeepPerformedLog(t *testing.T) { } func Test_RegistrySynchronizer1_3_UpkeepGasLimitSetLog(t *testing.T) { + ctx := testutils.Context(t) g := gomega.NewWithT(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_3) @@ -501,16 +507,17 @@ func Test_RegistrySynchronizer1_3_UpkeepGasLimitSetLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) g.Eventually(getExecuteGas, testutils.WaitTimeout(t), cltest.DBPollingInterval).Should(gomega.Equal(uint32(4_000_000))) } func Test_RegistrySynchronizer1_3_UpkeepReceivedLog(t *testing.T) { + ctx := testutils.Context(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_3) contractAddress := job.KeeperSpec.ContractAddress.Address() @@ -542,16 +549,17 @@ func Test_RegistrySynchronizer1_3_UpkeepReceivedLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) cltest.WaitForCount(t, db, "upkeep_registrations", 2) } func Test_RegistrySynchronizer1_3_UpkeepMigratedLog(t *testing.T) { + ctx := testutils.Context(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_3) contractAddress := job.KeeperSpec.ContractAddress.Address() @@ -580,17 +588,18 @@ func Test_RegistrySynchronizer1_3_UpkeepMigratedLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) // race condition: "wait for count" cltest.WaitForCount(t, db, "upkeep_registrations", 2) } func Test_RegistrySynchronizer1_3_UpkeepPausedLog_UpkeepUnpausedLog(t *testing.T) { + ctx := testutils.Context(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_3) contractAddress := job.KeeperSpec.ContractAddress.Address() @@ -620,11 +629,11 @@ func Test_RegistrySynchronizer1_3_UpkeepPausedLog_UpkeepUnpausedLog(t *testing.T logBroadcast.On("DecodedLog").Return(&log) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) cltest.WaitForCount(t, db, "upkeep_registrations", 2) @@ -635,11 +644,11 @@ func Test_RegistrySynchronizer1_3_UpkeepPausedLog_UpkeepUnpausedLog(t *testing.T logBroadcast.On("DecodedLog").Return(&unpausedlog) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) cltest.WaitForCount(t, db, "upkeep_registrations", 3) var upkeep keeper.UpkeepRegistration @@ -657,6 +666,7 @@ func Test_RegistrySynchronizer1_3_UpkeepPausedLog_UpkeepUnpausedLog(t *testing.T } func Test_RegistrySynchronizer1_3_UpkeepCheckDataUpdatedLog(t *testing.T) { + ctx := testutils.Context(t) g := gomega.NewWithT(t) db, synchronizer, ethMock, lb, job := setupRegistrySync(t, keeper.RegistryVersion_1_3) @@ -694,11 +704,11 @@ func Test_RegistrySynchronizer1_3_UpkeepCheckDataUpdatedLog(t *testing.T) { logBroadcast.On("DecodedLog").Return(&updatedLog) logBroadcast.On("RawLog").Return(rawLog) logBroadcast.On("String").Maybe().Return("") - lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) // Do the thing - synchronizer.HandleLog(logBroadcast) + synchronizer.HandleLog(ctx, logBroadcast) g.Eventually(func() []byte { var upkeep keeper.UpkeepRegistration diff --git a/core/services/keeper/registry_synchronizer_log_listener.go b/core/services/keeper/registry_synchronizer_log_listener.go index 099d01d27f6..93ff2e9e950 100644 --- a/core/services/keeper/registry_synchronizer_log_listener.go +++ b/core/services/keeper/registry_synchronizer_log_listener.go @@ -1,6 +1,7 @@ package keeper import ( + "context" "reflect" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/log" @@ -10,7 +11,7 @@ func (rs *RegistrySynchronizer) JobID() int32 { return rs.job.ID } -func (rs *RegistrySynchronizer) HandleLog(broadcast log.Broadcast) { +func (rs *RegistrySynchronizer) HandleLog(ctx context.Context, broadcast log.Broadcast) { eventLog := broadcast.DecodedLog() if eventLog == nil || reflect.ValueOf(eventLog).IsNil() { rs.logger.Panicf("HandleLog: ignoring nil value, type: %T", broadcast) diff --git a/core/services/keeper/registry_synchronizer_process_logs.go b/core/services/keeper/registry_synchronizer_process_logs.go index 0a0e1613c95..a1bdcd8db0b 100644 --- a/core/services/keeper/registry_synchronizer_process_logs.go +++ b/core/services/keeper/registry_synchronizer_process_logs.go @@ -85,7 +85,7 @@ func (rs *RegistrySynchronizer) processLogs(ctx context.Context) { rs.logger.Error(err) } - err = rs.logBroadcaster.MarkConsumed(ctx, broadcast) + err = rs.logBroadcaster.MarkConsumed(ctx, nil, broadcast) if err != nil { rs.logger.Error(errors.Wrapf(err, "unable to mark %T log as consumed, log: %v", broadcast.RawLog(), broadcast.String())) } diff --git a/core/services/ocr/contract_tracker.go b/core/services/ocr/contract_tracker.go index e4845ee3bc2..5746f97cd38 100644 --- a/core/services/ocr/contract_tracker.go +++ b/core/services/ocr/contract_tracker.go @@ -14,13 +14,12 @@ import ( gethTypes "github.com/ethereum/go-ethereum/core/types" "github.com/pkg/errors" - "github.com/jmoiron/sqlx" - "github.com/smartcontractkit/libocr/gethwrappers/offchainaggregator" "github.com/smartcontractkit/libocr/offchainreporting/confighelper" ocrtypes "github.com/smartcontractkit/libocr/offchainreporting/types" "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/utils/mailbox" "github.com/smartcontractkit/chainlink/v2/common/config" @@ -31,7 +30,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/generated/offchain_aggregator_wrapper" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/ocrcommon" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) // configMailboxSanityLimit is the maximum number of configs that can be held @@ -64,7 +62,7 @@ type ( jobID int32 logger logger.Logger ocrDB OCRContractTrackerDB - q pg.Q + ds sqlutil.DataSource blockTranslator ocrcommon.BlockTranslator cfg ocrcommon.Config mailMon *mailbox.Monitor @@ -92,8 +90,8 @@ type ( } OCRContractTrackerDB interface { - SaveLatestRoundRequested(tx pg.Queryer, rr offchainaggregator.OffchainAggregatorRoundRequested) error - LoadLatestRoundRequested() (rr offchainaggregator.OffchainAggregatorRoundRequested, err error) + SaveLatestRoundRequested(ctx context.Context, tx sqlutil.DataSource, rr offchainaggregator.OffchainAggregatorRoundRequested) error + LoadLatestRoundRequested(ctx context.Context) (rr offchainaggregator.OffchainAggregatorRoundRequested, err error) } ) @@ -112,10 +110,9 @@ func NewOCRContractTracker( logBroadcaster log.Broadcaster, jobID int32, logger logger.Logger, - db *sqlx.DB, + ds sqlutil.DataSource, ocrDB OCRContractTrackerDB, cfg ocrcommon.Config, - q pg.QConfig, headBroadcaster httypes.HeadBroadcaster, mailMon *mailbox.Monitor, ) (o *OCRContractTracker) { @@ -129,7 +126,7 @@ func NewOCRContractTracker( jobID: jobID, logger: logger, ocrDB: ocrDB, - q: pg.NewQ(db, logger, q), + ds: ds, blockTranslator: ocrcommon.NewBlockTranslator(cfg, ethClient, logger), cfg: cfg, mailMon: mailMon, @@ -144,9 +141,9 @@ func NewOCRContractTracker( // Start must be called before logs can be delivered // It ought to be called before starting OCR -func (t *OCRContractTracker) Start(context.Context) error { +func (t *OCRContractTracker) Start(ctx context.Context) error { return t.StartOnce("OCRContractTracker", func() (err error) { - t.latestRoundRequested, err = t.ocrDB.LoadLatestRoundRequested() + t.latestRoundRequested, err = t.ocrDB.LoadLatestRoundRequested(ctx) if err != nil { return errors.Wrap(err, "OCRContractTracker#Start: failed to load latest round requested") } @@ -240,10 +237,7 @@ func (t *OCRContractTracker) processLogs() { // HandleLog complies with LogListener interface // It is not thread safe -func (t *OCRContractTracker) HandleLog(lb log.Broadcast) { - ctx, cancel := t.chStop.NewCtx() - defer cancel() - +func (t *OCRContractTracker) HandleLog(ctx context.Context, lb log.Broadcast) { was, err := t.logBroadcaster.WasAlreadyConsumed(ctx, lb) if err != nil { t.logger.Errorw("could not determine if log was already consumed", "err", err) @@ -255,14 +249,14 @@ func (t *OCRContractTracker) HandleLog(lb log.Broadcast) { raw := lb.RawLog() if raw.Address != t.contract.Address() { t.logger.Errorf("log address of 0x%x does not match configured contract address of 0x%x", raw.Address, t.contract.Address()) - if err2 := t.logBroadcaster.MarkConsumed(ctx, lb); err2 != nil { + if err2 := t.logBroadcaster.MarkConsumed(ctx, nil, lb); err2 != nil { t.logger.Errorw("failed to mark log consumed", "err", err2) } return } topics := raw.Topics if len(topics) == 0 { - if err2 := t.logBroadcaster.MarkConsumed(ctx, lb); err2 != nil { + if err2 := t.logBroadcaster.MarkConsumed(ctx, nil, lb); err2 != nil { t.logger.Errorw("failed to mark log consumed", "err", err2) } return @@ -275,7 +269,7 @@ func (t *OCRContractTracker) HandleLog(lb log.Broadcast) { configSet, err = t.contractFilterer.ParseConfigSet(raw) if err != nil { t.logger.Errorw("could not parse config set", "err", err) - if err2 := t.logBroadcaster.MarkConsumed(ctx, lb); err2 != nil { + if err2 := t.logBroadcaster.MarkConsumed(ctx, nil, lb); err2 != nil { t.logger.Errorw("failed to mark log consumed", "err", err2) } return @@ -292,17 +286,17 @@ func (t *OCRContractTracker) HandleLog(lb log.Broadcast) { rr, err = t.contractFilterer.ParseRoundRequested(raw) if err != nil { t.logger.Errorw("could not parse round requested", "err", err) - if err2 := t.logBroadcaster.MarkConsumed(ctx, lb); err2 != nil { + if err2 := t.logBroadcaster.MarkConsumed(ctx, nil, lb); err2 != nil { t.logger.Errorw("failed to mark log consumed", "err", err2) } return } if IsLaterThan(raw, t.latestRoundRequested.Raw) { - err = t.q.Transaction(func(tx pg.Queryer) error { - if err = t.ocrDB.SaveLatestRoundRequested(tx, *rr); err != nil { + err = sqlutil.TransactDataSource(ctx, t.ds, nil, func(tx sqlutil.DataSource) error { + if err = t.ocrDB.SaveLatestRoundRequested(ctx, tx, *rr); err != nil { return err } - return t.logBroadcaster.MarkConsumed(ctx, lb) + return t.logBroadcaster.MarkConsumed(ctx, tx, lb) }) if err != nil { t.logger.Error(err) @@ -320,7 +314,7 @@ func (t *OCRContractTracker) HandleLog(lb log.Broadcast) { t.logger.Debugw("got unrecognised log topic", "topic", topics[0]) } if !consumed { - if err := t.logBroadcaster.MarkConsumed(ctx, lb); err != nil { + if err := t.logBroadcaster.MarkConsumed(ctx, nil, lb); err != nil { t.logger.Errorw("failed to mark log consumed", "err", err) } } diff --git a/core/services/ocr/contract_tracker_test.go b/core/services/ocr/contract_tracker_test.go index 678af35fa04..5473a2c924c 100644 --- a/core/services/ocr/contract_tracker_test.go +++ b/core/services/ocr/contract_tracker_test.go @@ -97,7 +97,6 @@ func newContractTrackerUni(t *testing.T, opts ...interface{}) (uni contractTrack db, uni.db, cfg.EVM(), - cfg.Database(), uni.hb, mailMon, ) @@ -146,7 +145,7 @@ func Test_OCRContractTracker_LatestBlockHeight(t *testing.T) { uni := newContractTrackerUni(t) uni.hb.On("Subscribe", uni.tracker).Return(&evmtypes.Head{Number: 42}, func() {}) - uni.db.On("LoadLatestRoundRequested").Return(offchainaggregator.OffchainAggregatorRoundRequested{}, nil) + uni.db.On("LoadLatestRoundRequested", mock.Anything).Return(offchainaggregator.OffchainAggregatorRoundRequested{}, nil) uni.lb.On("Register", uni.tracker, mock.Anything).Return(func() {}) servicetest.Run(t, uni.tracker) @@ -172,7 +171,7 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin rawLog := cltest.LogFromFixture(t, "../../testdata/jsonrpc/round_requested_log_1_1.json") logBroadcast.On("RawLog").Return(rawLog).Maybe() logBroadcast.On("String").Return("").Maybe() - uni.lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + uni.lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) uni.lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) configDigest, epoch, round, err := uni.tracker.LatestRoundRequested(testutils.Context(t), 0) @@ -181,7 +180,7 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin require.Equal(t, 0, int(round)) require.Equal(t, 0, int(epoch)) - uni.tracker.HandleLog(logBroadcast) + uni.tracker.HandleLog(testutils.Context(t), logBroadcast) configDigest, epoch, round, err = uni.tracker.LatestRoundRequested(testutils.Context(t), 0) require.NoError(t, err) @@ -203,7 +202,7 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin require.Equal(t, 0, int(round)) require.Equal(t, 0, int(epoch)) - uni.tracker.HandleLog(logBroadcast) + uni.tracker.HandleLog(testutils.Context(t), logBroadcast) configDigest, epoch, round, err = uni.tracker.LatestRoundRequested(testutils.Context(t), 0) require.NoError(t, err) @@ -228,13 +227,13 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin logBroadcast.On("RawLog").Return(rawLog).Maybe() logBroadcast.On("String").Return("").Maybe() uni.lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) - uni.lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + uni.lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) - uni.db.On("SaveLatestRoundRequested", mock.Anything, mock.MatchedBy(func(rr offchainaggregator.OffchainAggregatorRoundRequested) bool { + uni.db.On("SaveLatestRoundRequested", mock.Anything, mock.Anything, mock.MatchedBy(func(rr offchainaggregator.OffchainAggregatorRoundRequested) bool { return rr.Epoch == 1 && rr.Round == 1 })).Return(nil) - uni.tracker.HandleLog(logBroadcast) + uni.tracker.HandleLog(testutils.Context(t), logBroadcast) configDigest, epoch, round, err = uni.tracker.LatestRoundRequested(testutils.Context(t), 0) require.NoError(t, err) @@ -248,13 +247,13 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin logBroadcast2.On("RawLog").Return(rawLog2) logBroadcast2.On("String").Return("").Maybe() uni.lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) - uni.lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + uni.lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) - uni.db.On("SaveLatestRoundRequested", mock.Anything, mock.MatchedBy(func(rr offchainaggregator.OffchainAggregatorRoundRequested) bool { + uni.db.On("SaveLatestRoundRequested", mock.Anything, mock.Anything, mock.MatchedBy(func(rr offchainaggregator.OffchainAggregatorRoundRequested) bool { return rr.Epoch == 1 && rr.Round == 9 })).Return(nil) - uni.tracker.HandleLog(logBroadcast2) + uni.tracker.HandleLog(testutils.Context(t), logBroadcast2) configDigest, epoch, round, err = uni.tracker.LatestRoundRequested(testutils.Context(t), 0) require.NoError(t, err) @@ -263,7 +262,7 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin assert.Equal(t, 9, int(round)) // Same round with lower epoch is ignored - uni.tracker.HandleLog(logBroadcast) + uni.tracker.HandleLog(testutils.Context(t), logBroadcast) configDigest, epoch, round, err = uni.tracker.LatestRoundRequested(testutils.Context(t), 0) require.NoError(t, err) @@ -277,13 +276,13 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin logBroadcast3.On("RawLog").Return(rawLog3).Maybe() logBroadcast3.On("String").Return("").Maybe() uni.lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) - uni.lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + uni.lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) - uni.db.On("SaveLatestRoundRequested", mock.Anything, mock.MatchedBy(func(rr offchainaggregator.OffchainAggregatorRoundRequested) bool { + uni.db.On("SaveLatestRoundRequested", mock.Anything, mock.Anything, mock.MatchedBy(func(rr offchainaggregator.OffchainAggregatorRoundRequested) bool { return rr.Epoch == 2 && rr.Round == 1 })).Return(nil) - uni.tracker.HandleLog(logBroadcast3) + uni.tracker.HandleLog(testutils.Context(t), logBroadcast3) configDigest, epoch, round, err = uni.tracker.LatestRoundRequested(testutils.Context(t), 0) require.NoError(t, err) @@ -301,9 +300,9 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin logBroadcast.On("String").Return("").Maybe() uni.lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) - uni.db.On("SaveLatestRoundRequested", mock.Anything, mock.Anything).Return(errors.New("something exploded")) + uni.db.On("SaveLatestRoundRequested", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("something exploded")) - uni.tracker.HandleLog(logBroadcast) + uni.tracker.HandleLog(testutils.Context(t), logBroadcast) configDigest, epoch, round, err := uni.tracker.LatestRoundRequested(testutils.Context(t), 0) require.NoError(t, err) @@ -331,7 +330,7 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin eventuallyCloseHeadBroadcaster := cltest.NewAwaiter() uni.hb.On("Subscribe", uni.tracker).Return((*evmtypes.Head)(nil), func() { eventuallyCloseHeadBroadcaster.ItHappened() }) - uni.db.On("LoadLatestRoundRequested").Return(rr, nil) + uni.db.On("LoadLatestRoundRequested", mock.Anything).Return(rr, nil) require.NoError(t, uni.tracker.Start(testutils.Context(t))) diff --git a/core/services/ocr/database.go b/core/services/ocr/database.go index 977c371c15d..95993de9d5c 100644 --- a/core/services/ocr/database.go +++ b/core/services/ocr/database.go @@ -11,17 +11,16 @@ import ( "github.com/pkg/errors" "go.uber.org/multierr" - "github.com/jmoiron/sqlx" "github.com/smartcontractkit/libocr/gethwrappers/offchainaggregator" ocrtypes "github.com/smartcontractkit/libocr/offchainreporting/types" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big" "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) type db struct { - q pg.Q + ds sqlutil.DataSource oracleSpecID int32 lggr logger.SugaredLogger } @@ -32,11 +31,9 @@ var ( ) // NewDB returns a new DB scoped to this oracleSpecID -func NewDB(sqlxDB *sqlx.DB, oracleSpecID int32, lggr logger.Logger, cfg pg.QConfig) *db { - namedLogger := lggr.Named("OCR.DB") - +func NewDB(ds sqlutil.DataSource, oracleSpecID int32, lggr logger.Logger) *db { return &db{ - q: pg.NewQ(sqlxDB, namedLogger, cfg), + ds: ds, oracleSpecID: oracleSpecID, lggr: logger.Sugared(lggr), } @@ -54,7 +51,7 @@ func (d *db) ReadState(ctx context.Context, cd ocrtypes.ConfigDigest) (ps *ocrty var tmp []int64 var highestSentEpochTmp int64 - err = d.q.QueryRowxContext(ctx, stmt, d.oracleSpecID, cd).Scan(&ps.Epoch, &highestSentEpochTmp, pq.Array(&tmp)) + err = d.ds.QueryRowxContext(ctx, stmt, d.oracleSpecID, cd).Scan(&ps.Epoch, &highestSentEpochTmp, pq.Array(&tmp)) if errors.Is(err, sql.ErrNoRows) { return nil, nil } @@ -90,7 +87,9 @@ func (d *db) WriteState(ctx context.Context, cd ocrtypes.ConfigDigest, state ocr NOW() ) ` - _, err := d.q.WithOpts(pg.WithLongQueryTimeout()).ExecContext( + ctx, cancel := context.WithTimeout(sqlutil.WithoutDefaultTimeout(ctx), time.Minute) + defer cancel() + _, err := d.ds.ExecContext( ctx, stmt, d.oracleSpecID, cd, state.Epoch, state.HighestSentEpoch, pq.Array(&highestReceivedEpoch), ) @@ -109,7 +108,7 @@ func (d *db) ReadConfig(ctx context.Context) (c *ocrtypes.ContractConfig, err er var signers [][]byte var transmitters [][]byte - err = d.q.QueryRowContext(ctx, stmt, d.oracleSpecID).Scan( + err = d.ds.QueryRowxContext(ctx, stmt, d.oracleSpecID).Scan( &c.ConfigDigest, (*pq.ByteaArray)(&signers), (*pq.ByteaArray)(&transmitters), @@ -155,7 +154,7 @@ func (d *db) WriteConfig(ctx context.Context, c ocrtypes.ContractConfig) error { encoded = EXCLUDED.encoded, updated_at = NOW() ` - _, err := d.q.ExecContext(ctx, stmt, d.oracleSpecID, c.ConfigDigest, pq.ByteaArray(signers), pq.ByteaArray(transmitters), c.Threshold, int(c.EncodedConfigVersion), c.Encoded) + _, err := d.ds.ExecContext(ctx, stmt, d.oracleSpecID, c.ConfigDigest, pq.ByteaArray(signers), pq.ByteaArray(transmitters), c.Threshold, int(c.EncodedConfigVersion), c.Encoded) return errors.Wrap(err, "WriteConfig failed") } @@ -201,14 +200,14 @@ func (d *db) StorePendingTransmission(ctx context.Context, k ocrtypes.ReportTime updated_at = NOW() ` - _, err := d.q.ExecContext(ctx, stmt, d.oracleSpecID, k.ConfigDigest, k.Epoch, k.Round, p.Time, median, p.SerializedReport, pq.ByteaArray(rs), pq.ByteaArray(ss), p.Vs[:]) + _, err := d.ds.ExecContext(ctx, stmt, d.oracleSpecID, k.ConfigDigest, k.Epoch, k.Round, p.Time, median, p.SerializedReport, pq.ByteaArray(rs), pq.ByteaArray(ss), p.Vs[:]) return errors.Wrap(err, "StorePendingTransmission failed") } func (d *db) PendingTransmissionsWithConfigDigest(ctx context.Context, cd ocrtypes.ConfigDigest) (map[ocrtypes.ReportTimestamp]ocrtypes.PendingTransmission, error) { //nolint sqlclosecheck false positive - rows, err := d.q.QueryContext(ctx, ` + rows, err := d.ds.QueryContext(ctx, ` SELECT config_digest, epoch, @@ -269,7 +268,9 @@ WHERE ocr_oracle_spec_id = $1 AND config_digest = $2 } func (d *db) DeletePendingTransmission(ctx context.Context, k ocrtypes.ReportTimestamp) (err error) { - _, err = d.q.WithOpts(pg.WithLongQueryTimeout()).ExecContext(ctx, ` + ctx, cancel := context.WithTimeout(sqlutil.WithoutDefaultTimeout(ctx), time.Minute) + defer cancel() + _, err = d.ds.ExecContext(ctx, ` DELETE FROM ocr_pending_transmissions WHERE ocr_oracle_spec_id = $1 AND config_digest = $2 AND epoch = $3 AND round = $4 `, d.oracleSpecID, k.ConfigDigest, k.Epoch, k.Round) @@ -280,7 +281,9 @@ WHERE ocr_oracle_spec_id = $1 AND config_digest = $2 AND epoch = $3 AND round = } func (d *db) DeletePendingTransmissionsOlderThan(ctx context.Context, t time.Time) (err error) { - _, err = d.q.WithOpts(pg.WithLongQueryTimeout()).ExecContext(ctx, ` + ctx, cancel := context.WithTimeout(sqlutil.WithoutDefaultTimeout(ctx), time.Minute) + defer cancel() + _, err = d.ds.ExecContext(ctx, ` DELETE FROM ocr_pending_transmissions WHERE ocr_oracle_spec_id = $1 AND time < $2 `, d.oracleSpecID, t) @@ -290,12 +293,12 @@ WHERE ocr_oracle_spec_id = $1 AND time < $2 return } -func (d *db) SaveLatestRoundRequested(tx pg.Queryer, rr offchainaggregator.OffchainAggregatorRoundRequested) error { +func (d *db) SaveLatestRoundRequested(ctx context.Context, tx sqlutil.DataSource, rr offchainaggregator.OffchainAggregatorRoundRequested) error { rawLog, err := json.Marshal(rr.Raw) if err != nil { return errors.Wrap(err, "could not marshal log as JSON") } - _, err = tx.Exec(` + _, err = tx.ExecContext(ctx, ` INSERT INTO ocr_latest_round_requested (ocr_oracle_spec_id, requester, config_digest, epoch, round, raw) VALUES ($1,$2,$3,$4,$5,$6) ON CONFLICT (ocr_oracle_spec_id) DO UPDATE SET requester = EXCLUDED.requester, @@ -308,8 +311,8 @@ VALUES ($1,$2,$3,$4,$5,$6) ON CONFLICT (ocr_oracle_spec_id) DO UPDATE SET return errors.Wrap(err, "could not save latest round requested") } -func (d *db) LoadLatestRoundRequested() (rr offchainaggregator.OffchainAggregatorRoundRequested, err error) { - rows, err := d.q.Query(` +func (d *db) LoadLatestRoundRequested(ctx context.Context) (rr offchainaggregator.OffchainAggregatorRoundRequested, err error) { + rows, err := d.ds.QueryContext(ctx, ` SELECT requester, config_digest, epoch, round, raw FROM ocr_latest_round_requested WHERE ocr_oracle_spec_id = $1 diff --git a/core/services/ocr/database_test.go b/core/services/ocr/database_test.go index 5ccf257b2bb..8b8d64c49c9 100644 --- a/core/services/ocr/database_test.go +++ b/core/services/ocr/database_test.go @@ -410,7 +410,8 @@ func Test_DB_LatestRoundRequested(t *testing.T) { } t.Run("saves latest round requested", func(t *testing.T) { - err := odb.SaveLatestRoundRequested(sqlDB, rr) + ctx := testutils.Context(t) + err := odb.SaveLatestRoundRequested(ctx, sqlDB, rr) require.NoError(t, err) rawLog.Index = 42 @@ -424,17 +425,18 @@ func Test_DB_LatestRoundRequested(t *testing.T) { Raw: rawLog, } - err = odb.SaveLatestRoundRequested(sqlDB, rr) + err = odb.SaveLatestRoundRequested(ctx, sqlDB, rr) require.NoError(t, err) }) t.Run("loads latest round requested", func(t *testing.T) { + ctx := testutils.Context(t) // There is no round for db2 - lrr, err := odb2.LoadLatestRoundRequested() + lrr, err := odb2.LoadLatestRoundRequested(ctx) require.NoError(t, err) require.Equal(t, 0, int(lrr.Epoch)) - lrr, err = odb.LoadLatestRoundRequested() + lrr, err = odb.LoadLatestRoundRequested(ctx) require.NoError(t, err) assert.Equal(t, rr, lrr) diff --git a/core/services/ocr/delegate.go b/core/services/ocr/delegate.go index bcdda397e20..63055543f88 100644 --- a/core/services/ocr/delegate.go +++ b/core/services/ocr/delegate.go @@ -28,7 +28,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/job" "github.com/smartcontractkit/chainlink/v2/core/services/keystore" "github.com/smartcontractkit/chainlink/v2/core/services/ocrcommon" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" "github.com/smartcontractkit/chainlink/v2/core/services/synchronization" "github.com/smartcontractkit/chainlink/v2/core/services/telemetry" @@ -82,10 +81,10 @@ func (d *Delegate) JobType() job.Type { return job.OffchainReporting } -func (d *Delegate) BeforeJobCreated(spec job.Job) {} -func (d *Delegate) AfterJobCreated(spec job.Job) {} -func (d *Delegate) BeforeJobDeleted(spec job.Job) {} -func (d *Delegate) OnDeleteJob(ctx context.Context, spec job.Job, q pg.Queryer) error { return nil } +func (d *Delegate) BeforeJobCreated(spec job.Job) {} +func (d *Delegate) AfterJobCreated(spec job.Job) {} +func (d *Delegate) BeforeJobDeleted(spec job.Job) {} +func (d *Delegate) OnDeleteJob(context.Context, job.Job) error { return nil } // ServicesForSpec returns the OCR services that need to run for this job func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) (services []job.ServiceCtx, err error) { @@ -121,7 +120,7 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) (services [] return nil, errors.Wrap(err, "could not instantiate NewOffchainAggregatorCaller") } - ocrDB := NewDB(d.db, concreteSpec.ID, lggr, d.cfg) + ocrDB := NewDB(d.db, concreteSpec.ID, lggr) tracker := NewOCRContractTracker( contract, @@ -134,7 +133,6 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) (services [] d.db, ocrDB, chain.Config().EVM(), - chain.Config().Database(), chain.HeadBroadcaster(), d.mailMon, ) diff --git a/core/services/ocr/helpers_internal_test.go b/core/services/ocr/helpers_internal_test.go index 57b669ef401..c6a3d1ac401 100644 --- a/core/services/ocr/helpers_internal_test.go +++ b/core/services/ocr/helpers_internal_test.go @@ -5,7 +5,6 @@ import ( "github.com/jmoiron/sqlx" - "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/pgtest" "github.com/smartcontractkit/chainlink/v2/core/logger" ) @@ -14,5 +13,5 @@ func (c *ConfigOverriderImpl) ExportedUpdateFlagsStatus() error { } func NewTestDB(t *testing.T, sqldb *sqlx.DB, oracleSpecID int32) *db { - return NewDB(sqldb, oracleSpecID, logger.TestLogger(t), pgtest.NewQConfig(true)) + return NewDB(sqldb, oracleSpecID, logger.TestLogger(t)) } diff --git a/core/services/ocr/mocks/ocr_contract_tracker_db.go b/core/services/ocr/mocks/ocr_contract_tracker_db.go index 6724e418014..42eebf939d7 100644 --- a/core/services/ocr/mocks/ocr_contract_tracker_db.go +++ b/core/services/ocr/mocks/ocr_contract_tracker_db.go @@ -3,11 +3,13 @@ package mocks import ( + context "context" + mock "github.com/stretchr/testify/mock" offchainaggregator "github.com/smartcontractkit/libocr/gethwrappers/offchainaggregator" - pg "github.com/smartcontractkit/chainlink/v2/core/services/pg" + sqlutil "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" ) // OCRContractTrackerDB is an autogenerated mock type for the OCRContractTrackerDB type @@ -15,9 +17,9 @@ type OCRContractTrackerDB struct { mock.Mock } -// LoadLatestRoundRequested provides a mock function with given fields: -func (_m *OCRContractTrackerDB) LoadLatestRoundRequested() (offchainaggregator.OffchainAggregatorRoundRequested, error) { - ret := _m.Called() +// LoadLatestRoundRequested provides a mock function with given fields: ctx +func (_m *OCRContractTrackerDB) LoadLatestRoundRequested(ctx context.Context) (offchainaggregator.OffchainAggregatorRoundRequested, error) { + ret := _m.Called(ctx) if len(ret) == 0 { panic("no return value specified for LoadLatestRoundRequested") @@ -25,17 +27,17 @@ func (_m *OCRContractTrackerDB) LoadLatestRoundRequested() (offchainaggregator.O var r0 offchainaggregator.OffchainAggregatorRoundRequested var r1 error - if rf, ok := ret.Get(0).(func() (offchainaggregator.OffchainAggregatorRoundRequested, error)); ok { - return rf() + if rf, ok := ret.Get(0).(func(context.Context) (offchainaggregator.OffchainAggregatorRoundRequested, error)); ok { + return rf(ctx) } - if rf, ok := ret.Get(0).(func() offchainaggregator.OffchainAggregatorRoundRequested); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) offchainaggregator.OffchainAggregatorRoundRequested); ok { + r0 = rf(ctx) } else { r0 = ret.Get(0).(offchainaggregator.OffchainAggregatorRoundRequested) } - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -43,17 +45,17 @@ func (_m *OCRContractTrackerDB) LoadLatestRoundRequested() (offchainaggregator.O return r0, r1 } -// SaveLatestRoundRequested provides a mock function with given fields: tx, rr -func (_m *OCRContractTrackerDB) SaveLatestRoundRequested(tx pg.Queryer, rr offchainaggregator.OffchainAggregatorRoundRequested) error { - ret := _m.Called(tx, rr) +// SaveLatestRoundRequested provides a mock function with given fields: ctx, tx, rr +func (_m *OCRContractTrackerDB) SaveLatestRoundRequested(ctx context.Context, tx sqlutil.DataSource, rr offchainaggregator.OffchainAggregatorRoundRequested) error { + ret := _m.Called(ctx, tx, rr) if len(ret) == 0 { panic("no return value specified for SaveLatestRoundRequested") } var r0 error - if rf, ok := ret.Get(0).(func(pg.Queryer, offchainaggregator.OffchainAggregatorRoundRequested) error); ok { - r0 = rf(tx, rr) + if rf, ok := ret.Get(0).(func(context.Context, sqlutil.DataSource, offchainaggregator.OffchainAggregatorRoundRequested) error); ok { + r0 = rf(ctx, tx, rr) } else { r0 = ret.Error(0) } diff --git a/core/services/ocr2/delegate.go b/core/services/ocr2/delegate.go index a00ed195903..da6d6a1b6e7 100644 --- a/core/services/ocr2/delegate.go +++ b/core/services/ocr2/delegate.go @@ -278,7 +278,7 @@ func (d *Delegate) BeforeJobCreated(spec job.Job) { } func (d *Delegate) AfterJobCreated(spec job.Job) {} func (d *Delegate) BeforeJobDeleted(spec job.Job) {} -func (d *Delegate) OnDeleteJob(ctx context.Context, jb job.Job, q pg.Queryer) error { +func (d *Delegate) OnDeleteJob(ctx context.Context, jb job.Job) error { // If the job spec is malformed in any way, we report the error but return nil so that // the job deletion itself isn't blocked. @@ -295,13 +295,13 @@ func (d *Delegate) OnDeleteJob(ctx context.Context, jb job.Job, q pg.Queryer) er } // we only have clean to do for the EVM if rid.Network == relay.EVM { - return d.cleanupEVM(ctx, jb, q, rid) + return d.cleanupEVM(ctx, jb, rid) } return nil } // cleanupEVM is a helper for clean up EVM specific state when a job is deleted -func (d *Delegate) cleanupEVM(ctx context.Context, jb job.Job, q pg.Queryer, relayID relay.ID) error { +func (d *Delegate) cleanupEVM(ctx context.Context, jb job.Job, relayID relay.ID) error { // If UnregisterFilter returns an // error, that means it failed to remove a valid active filter from the db. We do abort the job deletion // in that case, since it should be easy for the user to retry and will avoid leaving the db in diff --git a/core/services/ocr2/plugins/generic/pipeline_runner_adapter_test.go b/core/services/ocr2/plugins/generic/pipeline_runner_adapter_test.go index 647aaf59056..d8678844d25 100644 --- a/core/services/ocr2/plugins/generic/pipeline_runner_adapter_test.go +++ b/core/services/ocr2/plugins/generic/pipeline_runner_adapter_test.go @@ -43,7 +43,7 @@ func TestAdapter_Integration(t *testing.T) { require.NoError(t, err) keystore := keystore.NewInMemory(db, utils.FastScryptParams, logger, cfg.Database()) - pipelineORM := pipeline.NewORM(db, logger, cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) + pipelineORM := pipeline.NewORM(db, logger, cfg.JobPipeline().MaxSuccessfulRuns()) bridgesORM := bridges.NewORM(db) jobORM := job.NewORM(db, pipelineORM, bridgesORM, keystore, logger, cfg.Database()) pr := pipeline.NewRunner( diff --git a/core/services/ocrbootstrap/delegate.go b/core/services/ocrbootstrap/delegate.go index 9ed7cbea477..2d87cf80346 100644 --- a/core/services/ocrbootstrap/delegate.go +++ b/core/services/ocrbootstrap/delegate.go @@ -19,7 +19,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/job" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/validate" "github.com/smartcontractkit/chainlink/v2/core/services/ocrcommon" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/relay" ) @@ -190,6 +189,6 @@ func (d *Delegate) AfterJobCreated(spec job.Job) { func (d *Delegate) BeforeJobDeleted(spec job.Job) {} // OnDeleteJob satisfies the job.Delegate interface. -func (d *Delegate) OnDeleteJob(ctx context.Context, spec job.Job, q pg.Queryer) error { +func (d *Delegate) OnDeleteJob(context.Context, job.Job) error { return nil } diff --git a/core/services/ocrcommon/run_saver.go b/core/services/ocrcommon/run_saver.go index 6d85aa857a4..52ffb31cea0 100644 --- a/core/services/ocrcommon/run_saver.go +++ b/core/services/ocrcommon/run_saver.go @@ -2,15 +2,16 @@ package ocrcommon import ( "context" + "time" "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" ) type Runner interface { - InsertFinishedRun(run *pipeline.Run, saveSuccessfulTaskRuns bool, qopts ...pg.QOpt) error + InsertFinishedRun(ctx context.Context, ds sqlutil.DataSource, run *pipeline.Run, saveSuccessfulTaskRuns bool) error } type RunResultSaver struct { @@ -19,7 +20,7 @@ type RunResultSaver struct { maxSuccessfulRuns uint64 runResults chan *pipeline.Run pipelineRunner Runner - done chan struct{} + stopCh services.StopChan logger logger.Logger } @@ -36,7 +37,7 @@ func NewResultRunSaver(pipelineRunner Runner, maxSuccessfulRuns: maxSuccessfulRuns, runResults: make(chan *pipeline.Run, resultsWriteDepth), pipelineRunner: pipelineRunner, - done: make(chan struct{}), + stopCh: make(chan struct{}), logger: logger.Named("RunResultSaver"), } } @@ -55,6 +56,8 @@ func (r *RunResultSaver) Save(run *pipeline.Run) { func (r *RunResultSaver) Start(context.Context) error { return r.StartOnce("RunResultSaver", func() error { go func() { + ctx, cancel := r.stopCh.NewCtx() + defer cancel() for { select { case run := <-r.runResults: @@ -66,10 +69,10 @@ func (r *RunResultSaver) Start(context.Context) error { r.logger.Tracew("RunSaver: saving job run", "run", run) // We do not want save successful TaskRuns as OCR runs very frequently so a lot of records // are produced and the successful TaskRuns do not provide value. - if err := r.pipelineRunner.InsertFinishedRun(run, false); err != nil { + if err := r.pipelineRunner.InsertFinishedRun(ctx, nil, run, false); err != nil { r.logger.Errorw("error inserting finished results", "err", err) } - case <-r.done: + case <-r.stopCh: return } } @@ -80,7 +83,10 @@ func (r *RunResultSaver) Start(context.Context) error { func (r *RunResultSaver) Close() error { return r.StopOnce("RunResultSaver", func() error { - r.done <- struct{}{} + close(r.stopCh) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() // In the unlikely event that there are remaining runResults to write, // drain the channel and save them. @@ -88,7 +94,7 @@ func (r *RunResultSaver) Close() error { select { case run := <-r.runResults: r.logger.Infow("RunSaver: saving job run before exiting", "run", run) - if err := r.pipelineRunner.InsertFinishedRun(run, false); err != nil { + if err := r.pipelineRunner.InsertFinishedRun(ctx, nil, run, false); err != nil { r.logger.Errorw("error inserting finished results", "err", err) } default: diff --git a/core/services/ocrcommon/run_saver_test.go b/core/services/ocrcommon/run_saver_test.go index 7bfe60f2a06..a965792ca1f 100644 --- a/core/services/ocrcommon/run_saver_test.go +++ b/core/services/ocrcommon/run_saver_test.go @@ -25,7 +25,7 @@ func TestRunSaver(t *testing.T) { pipelineRunner.On("InsertFinishedRun", mock.Anything, mock.Anything, mock.Anything, mock.Anything). Return(nil). Run(func(args mock.Arguments) { - args.Get(0).(*pipeline.Run).ID = int64(d) + args.Get(2).(*pipeline.Run).ID = int64(d) }). Once() rs.Save(&pipeline.Run{ID: int64(i)}) diff --git a/core/services/pipeline/helpers_test.go b/core/services/pipeline/helpers_test.go index 9ee2dc693f2..0bbdef7a7f2 100644 --- a/core/services/pipeline/helpers_test.go +++ b/core/services/pipeline/helpers_test.go @@ -5,6 +5,7 @@ import ( "github.com/google/uuid" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/bridges" "github.com/smartcontractkit/chainlink/v2/core/chains/legacyevm" ) @@ -63,3 +64,5 @@ func (t *ETHTxTask) HelperSetDependencies(legacyChains legacyevm.LegacyChainCont t.specGasLimit = specGasLimit t.jobType = jobType } + +func (o *orm) Prune(ds sqlutil.DataSource, pipelineSpecID int32) { o.prune(ds, pipelineSpecID) } diff --git a/core/services/pipeline/mocks/orm.go b/core/services/pipeline/mocks/orm.go index b06041767a1..fe9aa2823a4 100644 --- a/core/services/pipeline/mocks/orm.go +++ b/core/services/pipeline/mocks/orm.go @@ -8,10 +8,10 @@ import ( models "github.com/smartcontractkit/chainlink/v2/core/store/models" mock "github.com/stretchr/testify/mock" - pg "github.com/smartcontractkit/chainlink/v2/core/services/pg" - pipeline "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" + sqlutil "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" + time "time" uuid "github.com/google/uuid" @@ -40,24 +40,17 @@ func (_m *ORM) Close() error { return r0 } -// CreateRun provides a mock function with given fields: run, qopts -func (_m *ORM) CreateRun(run *pipeline.Run, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, run) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// CreateRun provides a mock function with given fields: ctx, run +func (_m *ORM) CreateRun(ctx context.Context, run *pipeline.Run) error { + ret := _m.Called(ctx, run) if len(ret) == 0 { panic("no return value specified for CreateRun") } var r0 error - if rf, ok := ret.Get(0).(func(*pipeline.Run, ...pg.QOpt) error); ok { - r0 = rf(run, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, *pipeline.Run) error); ok { + r0 = rf(ctx, run) } else { r0 = ret.Error(0) } @@ -65,16 +58,9 @@ func (_m *ORM) CreateRun(run *pipeline.Run, qopts ...pg.QOpt) error { return r0 } -// CreateSpec provides a mock function with given fields: _a0, maxTaskTimeout, qopts -func (_m *ORM) CreateSpec(_a0 pipeline.Pipeline, maxTaskTimeout models.Interval, qopts ...pg.QOpt) (int32, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, _a0, maxTaskTimeout) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// CreateSpec provides a mock function with given fields: ctx, ds, _a2, maxTaskTimeout +func (_m *ORM) CreateSpec(ctx context.Context, ds pipeline.CreateDataSource, _a2 pipeline.Pipeline, maxTaskTimeout models.Interval) (int32, error) { + ret := _m.Called(ctx, ds, _a2, maxTaskTimeout) if len(ret) == 0 { panic("no return value specified for CreateSpec") @@ -82,17 +68,17 @@ func (_m *ORM) CreateSpec(_a0 pipeline.Pipeline, maxTaskTimeout models.Interval, var r0 int32 var r1 error - if rf, ok := ret.Get(0).(func(pipeline.Pipeline, models.Interval, ...pg.QOpt) (int32, error)); ok { - return rf(_a0, maxTaskTimeout, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, pipeline.CreateDataSource, pipeline.Pipeline, models.Interval) (int32, error)); ok { + return rf(ctx, ds, _a2, maxTaskTimeout) } - if rf, ok := ret.Get(0).(func(pipeline.Pipeline, models.Interval, ...pg.QOpt) int32); ok { - r0 = rf(_a0, maxTaskTimeout, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, pipeline.CreateDataSource, pipeline.Pipeline, models.Interval) int32); ok { + r0 = rf(ctx, ds, _a2, maxTaskTimeout) } else { r0 = ret.Get(0).(int32) } - if rf, ok := ret.Get(1).(func(pipeline.Pipeline, models.Interval, ...pg.QOpt) error); ok { - r1 = rf(_a0, maxTaskTimeout, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, pipeline.CreateDataSource, pipeline.Pipeline, models.Interval) error); ok { + r1 = rf(ctx, ds, _a2, maxTaskTimeout) } else { r1 = ret.Error(1) } @@ -100,17 +86,37 @@ func (_m *ORM) CreateSpec(_a0 pipeline.Pipeline, maxTaskTimeout models.Interval, return r0, r1 } -// DeleteRun provides a mock function with given fields: id -func (_m *ORM) DeleteRun(id int64) error { - ret := _m.Called(id) +// DataSource provides a mock function with given fields: +func (_m *ORM) DataSource() sqlutil.DataSource { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for DataSource") + } + + var r0 sqlutil.DataSource + if rf, ok := ret.Get(0).(func() sqlutil.DataSource); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(sqlutil.DataSource) + } + } + + return r0 +} + +// DeleteRun provides a mock function with given fields: ctx, id +func (_m *ORM) DeleteRun(ctx context.Context, id int64) error { + ret := _m.Called(ctx, id) if len(ret) == 0 { panic("no return value specified for DeleteRun") } var r0 error - if rf, ok := ret.Get(0).(func(int64) error); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int64) error); ok { + r0 = rf(ctx, id) } else { r0 = ret.Error(0) } @@ -136,9 +142,9 @@ func (_m *ORM) DeleteRunsOlderThan(_a0 context.Context, _a1 time.Duration) error return r0 } -// FindRun provides a mock function with given fields: id -func (_m *ORM) FindRun(id int64) (pipeline.Run, error) { - ret := _m.Called(id) +// FindRun provides a mock function with given fields: ctx, id +func (_m *ORM) FindRun(ctx context.Context, id int64) (pipeline.Run, error) { + ret := _m.Called(ctx, id) if len(ret) == 0 { panic("no return value specified for FindRun") @@ -146,17 +152,17 @@ func (_m *ORM) FindRun(id int64) (pipeline.Run, error) { var r0 pipeline.Run var r1 error - if rf, ok := ret.Get(0).(func(int64) (pipeline.Run, error)); ok { - return rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int64) (pipeline.Run, error)); ok { + return rf(ctx, id) } - if rf, ok := ret.Get(0).(func(int64) pipeline.Run); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(context.Context, int64) pipeline.Run); ok { + r0 = rf(ctx, id) } else { r0 = ret.Get(0).(pipeline.Run) } - if rf, ok := ret.Get(1).(func(int64) error); ok { - r1 = rf(id) + if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok { + r1 = rf(ctx, id) } else { r1 = ret.Error(1) } @@ -164,9 +170,9 @@ func (_m *ORM) FindRun(id int64) (pipeline.Run, error) { return r0, r1 } -// GetAllRuns provides a mock function with given fields: -func (_m *ORM) GetAllRuns() ([]pipeline.Run, error) { - ret := _m.Called() +// GetAllRuns provides a mock function with given fields: ctx +func (_m *ORM) GetAllRuns(ctx context.Context) ([]pipeline.Run, error) { + ret := _m.Called(ctx) if len(ret) == 0 { panic("no return value specified for GetAllRuns") @@ -174,19 +180,19 @@ func (_m *ORM) GetAllRuns() ([]pipeline.Run, error) { var r0 []pipeline.Run var r1 error - if rf, ok := ret.Get(0).(func() ([]pipeline.Run, error)); ok { - return rf() + if rf, ok := ret.Get(0).(func(context.Context) ([]pipeline.Run, error)); ok { + return rf(ctx) } - if rf, ok := ret.Get(0).(func() []pipeline.Run); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) []pipeline.Run); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]pipeline.Run) } } - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -194,24 +200,6 @@ func (_m *ORM) GetAllRuns() ([]pipeline.Run, error) { return r0, r1 } -// GetQ provides a mock function with given fields: -func (_m *ORM) GetQ() pg.Q { - ret := _m.Called() - - if len(ret) == 0 { - panic("no return value specified for GetQ") - } - - var r0 pg.Q - if rf, ok := ret.Get(0).(func() pg.Q); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(pg.Q) - } - - return r0 -} - // GetUnfinishedRuns provides a mock function with given fields: _a0, _a1, _a2 func (_m *ORM) GetUnfinishedRuns(_a0 context.Context, _a1 time.Time, _a2 func(pipeline.Run) error) error { ret := _m.Called(_a0, _a1, _a2) @@ -250,24 +238,17 @@ func (_m *ORM) HealthReport() map[string]error { return r0 } -// InsertFinishedRun provides a mock function with given fields: run, saveSuccessfulTaskRuns, qopts -func (_m *ORM) InsertFinishedRun(run *pipeline.Run, saveSuccessfulTaskRuns bool, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, run, saveSuccessfulTaskRuns) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// InsertFinishedRun provides a mock function with given fields: ctx, run, saveSuccessfulTaskRuns +func (_m *ORM) InsertFinishedRun(ctx context.Context, run *pipeline.Run, saveSuccessfulTaskRuns bool) error { + ret := _m.Called(ctx, run, saveSuccessfulTaskRuns) if len(ret) == 0 { panic("no return value specified for InsertFinishedRun") } var r0 error - if rf, ok := ret.Get(0).(func(*pipeline.Run, bool, ...pg.QOpt) error); ok { - r0 = rf(run, saveSuccessfulTaskRuns, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, *pipeline.Run, bool) error); ok { + r0 = rf(ctx, run, saveSuccessfulTaskRuns) } else { r0 = ret.Error(0) } @@ -275,24 +256,17 @@ func (_m *ORM) InsertFinishedRun(run *pipeline.Run, saveSuccessfulTaskRuns bool, return r0 } -// InsertFinishedRunWithSpec provides a mock function with given fields: run, saveSuccessfulTaskRuns, qopts -func (_m *ORM) InsertFinishedRunWithSpec(run *pipeline.Run, saveSuccessfulTaskRuns bool, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, run, saveSuccessfulTaskRuns) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// InsertFinishedRunWithSpec provides a mock function with given fields: ctx, run, saveSuccessfulTaskRuns +func (_m *ORM) InsertFinishedRunWithSpec(ctx context.Context, run *pipeline.Run, saveSuccessfulTaskRuns bool) error { + ret := _m.Called(ctx, run, saveSuccessfulTaskRuns) if len(ret) == 0 { panic("no return value specified for InsertFinishedRunWithSpec") } var r0 error - if rf, ok := ret.Get(0).(func(*pipeline.Run, bool, ...pg.QOpt) error); ok { - r0 = rf(run, saveSuccessfulTaskRuns, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, *pipeline.Run, bool) error); ok { + r0 = rf(ctx, run, saveSuccessfulTaskRuns) } else { r0 = ret.Error(0) } @@ -300,24 +274,17 @@ func (_m *ORM) InsertFinishedRunWithSpec(run *pipeline.Run, saveSuccessfulTaskRu return r0 } -// InsertFinishedRuns provides a mock function with given fields: run, saveSuccessfulTaskRuns, qopts -func (_m *ORM) InsertFinishedRuns(run []*pipeline.Run, saveSuccessfulTaskRuns bool, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, run, saveSuccessfulTaskRuns) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// InsertFinishedRuns provides a mock function with given fields: ctx, run, saveSuccessfulTaskRuns +func (_m *ORM) InsertFinishedRuns(ctx context.Context, run []*pipeline.Run, saveSuccessfulTaskRuns bool) error { + ret := _m.Called(ctx, run, saveSuccessfulTaskRuns) if len(ret) == 0 { panic("no return value specified for InsertFinishedRuns") } var r0 error - if rf, ok := ret.Get(0).(func([]*pipeline.Run, bool, ...pg.QOpt) error); ok { - r0 = rf(run, saveSuccessfulTaskRuns, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, []*pipeline.Run, bool) error); ok { + r0 = rf(ctx, run, saveSuccessfulTaskRuns) } else { r0 = ret.Error(0) } @@ -325,24 +292,17 @@ func (_m *ORM) InsertFinishedRuns(run []*pipeline.Run, saveSuccessfulTaskRuns bo return r0 } -// InsertRun provides a mock function with given fields: run, qopts -func (_m *ORM) InsertRun(run *pipeline.Run, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, run) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// InsertRun provides a mock function with given fields: ctx, run +func (_m *ORM) InsertRun(ctx context.Context, run *pipeline.Run) error { + ret := _m.Called(ctx, run) if len(ret) == 0 { panic("no return value specified for InsertRun") } var r0 error - if rf, ok := ret.Get(0).(func(*pipeline.Run, ...pg.QOpt) error); ok { - r0 = rf(run, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, *pipeline.Run) error); ok { + r0 = rf(ctx, run) } else { r0 = ret.Error(0) } @@ -404,16 +364,9 @@ func (_m *ORM) Start(_a0 context.Context) error { return r0 } -// StoreRun provides a mock function with given fields: run, qopts -func (_m *ORM) StoreRun(run *pipeline.Run, qopts ...pg.QOpt) (bool, error) { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, run) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// StoreRun provides a mock function with given fields: ctx, run +func (_m *ORM) StoreRun(ctx context.Context, run *pipeline.Run) (bool, error) { + ret := _m.Called(ctx, run) if len(ret) == 0 { panic("no return value specified for StoreRun") @@ -421,17 +374,17 @@ func (_m *ORM) StoreRun(run *pipeline.Run, qopts ...pg.QOpt) (bool, error) { var r0 bool var r1 error - if rf, ok := ret.Get(0).(func(*pipeline.Run, ...pg.QOpt) (bool, error)); ok { - return rf(run, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, *pipeline.Run) (bool, error)); ok { + return rf(ctx, run) } - if rf, ok := ret.Get(0).(func(*pipeline.Run, ...pg.QOpt) bool); ok { - r0 = rf(run, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, *pipeline.Run) bool); ok { + r0 = rf(ctx, run) } else { r0 = ret.Get(0).(bool) } - if rf, ok := ret.Get(1).(func(*pipeline.Run, ...pg.QOpt) error); ok { - r1 = rf(run, qopts...) + if rf, ok := ret.Get(1).(func(context.Context, *pipeline.Run) error); ok { + r1 = rf(ctx, run) } else { r1 = ret.Error(1) } @@ -439,9 +392,27 @@ func (_m *ORM) StoreRun(run *pipeline.Run, qopts ...pg.QOpt) (bool, error) { return r0, r1 } -// UpdateTaskRunResult provides a mock function with given fields: taskID, result -func (_m *ORM) UpdateTaskRunResult(taskID uuid.UUID, result pipeline.Result) (pipeline.Run, bool, error) { - ret := _m.Called(taskID, result) +// Transact provides a mock function with given fields: _a0, _a1 +func (_m *ORM) Transact(_a0 context.Context, _a1 func(pipeline.ORM) error) error { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for Transact") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, func(pipeline.ORM) error) error); ok { + r0 = rf(_a0, _a1) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// UpdateTaskRunResult provides a mock function with given fields: ctx, taskID, result +func (_m *ORM) UpdateTaskRunResult(ctx context.Context, taskID uuid.UUID, result pipeline.Result) (pipeline.Run, bool, error) { + ret := _m.Called(ctx, taskID, result) if len(ret) == 0 { panic("no return value specified for UpdateTaskRunResult") @@ -450,23 +421,23 @@ func (_m *ORM) UpdateTaskRunResult(taskID uuid.UUID, result pipeline.Result) (pi var r0 pipeline.Run var r1 bool var r2 error - if rf, ok := ret.Get(0).(func(uuid.UUID, pipeline.Result) (pipeline.Run, bool, error)); ok { - return rf(taskID, result) + if rf, ok := ret.Get(0).(func(context.Context, uuid.UUID, pipeline.Result) (pipeline.Run, bool, error)); ok { + return rf(ctx, taskID, result) } - if rf, ok := ret.Get(0).(func(uuid.UUID, pipeline.Result) pipeline.Run); ok { - r0 = rf(taskID, result) + if rf, ok := ret.Get(0).(func(context.Context, uuid.UUID, pipeline.Result) pipeline.Run); ok { + r0 = rf(ctx, taskID, result) } else { r0 = ret.Get(0).(pipeline.Run) } - if rf, ok := ret.Get(1).(func(uuid.UUID, pipeline.Result) bool); ok { - r1 = rf(taskID, result) + if rf, ok := ret.Get(1).(func(context.Context, uuid.UUID, pipeline.Result) bool); ok { + r1 = rf(ctx, taskID, result) } else { r1 = ret.Get(1).(bool) } - if rf, ok := ret.Get(2).(func(uuid.UUID, pipeline.Result) error); ok { - r2 = rf(taskID, result) + if rf, ok := ret.Get(2).(func(context.Context, uuid.UUID, pipeline.Result) error); ok { + r2 = rf(ctx, taskID, result) } else { r2 = ret.Error(2) } @@ -474,6 +445,26 @@ func (_m *ORM) UpdateTaskRunResult(taskID uuid.UUID, result pipeline.Result) (pi return r0, r1, r2 } +// WithDataSource provides a mock function with given fields: _a0 +func (_m *ORM) WithDataSource(_a0 sqlutil.DataSource) pipeline.ORM { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for WithDataSource") + } + + var r0 pipeline.ORM + if rf, ok := ret.Get(0).(func(sqlutil.DataSource) pipeline.ORM); ok { + r0 = rf(_a0) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(pipeline.ORM) + } + } + + return r0 +} + // NewORM creates a new instance of ORM. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewORM(t interface { diff --git a/core/services/pipeline/mocks/runner.go b/core/services/pipeline/mocks/runner.go index 3de2703f0c7..e0378399f58 100644 --- a/core/services/pipeline/mocks/runner.go +++ b/core/services/pipeline/mocks/runner.go @@ -8,10 +8,10 @@ import ( logger "github.com/smartcontractkit/chainlink/v2/core/logger" mock "github.com/stretchr/testify/mock" - pg "github.com/smartcontractkit/chainlink/v2/core/services/pg" - pipeline "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" + sqlutil "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" + uuid "github.com/google/uuid" ) @@ -164,24 +164,17 @@ func (_m *Runner) InitializePipeline(spec pipeline.Spec) (*pipeline.Pipeline, er return r0, r1 } -// InsertFinishedRun provides a mock function with given fields: run, saveSuccessfulTaskRuns, qopts -func (_m *Runner) InsertFinishedRun(run *pipeline.Run, saveSuccessfulTaskRuns bool, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, run, saveSuccessfulTaskRuns) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// InsertFinishedRun provides a mock function with given fields: ctx, ds, run, saveSuccessfulTaskRuns +func (_m *Runner) InsertFinishedRun(ctx context.Context, ds sqlutil.DataSource, run *pipeline.Run, saveSuccessfulTaskRuns bool) error { + ret := _m.Called(ctx, ds, run, saveSuccessfulTaskRuns) if len(ret) == 0 { panic("no return value specified for InsertFinishedRun") } var r0 error - if rf, ok := ret.Get(0).(func(*pipeline.Run, bool, ...pg.QOpt) error); ok { - r0 = rf(run, saveSuccessfulTaskRuns, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, sqlutil.DataSource, *pipeline.Run, bool) error); ok { + r0 = rf(ctx, ds, run, saveSuccessfulTaskRuns) } else { r0 = ret.Error(0) } @@ -189,24 +182,17 @@ func (_m *Runner) InsertFinishedRun(run *pipeline.Run, saveSuccessfulTaskRuns bo return r0 } -// InsertFinishedRuns provides a mock function with given fields: runs, saveSuccessfulTaskRuns, qopts -func (_m *Runner) InsertFinishedRuns(runs []*pipeline.Run, saveSuccessfulTaskRuns bool, qopts ...pg.QOpt) error { - _va := make([]interface{}, len(qopts)) - for _i := range qopts { - _va[_i] = qopts[_i] - } - var _ca []interface{} - _ca = append(_ca, runs, saveSuccessfulTaskRuns) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// InsertFinishedRuns provides a mock function with given fields: ctx, ds, runs, saveSuccessfulTaskRuns +func (_m *Runner) InsertFinishedRuns(ctx context.Context, ds sqlutil.DataSource, runs []*pipeline.Run, saveSuccessfulTaskRuns bool) error { + ret := _m.Called(ctx, ds, runs, saveSuccessfulTaskRuns) if len(ret) == 0 { panic("no return value specified for InsertFinishedRuns") } var r0 error - if rf, ok := ret.Get(0).(func([]*pipeline.Run, bool, ...pg.QOpt) error); ok { - r0 = rf(runs, saveSuccessfulTaskRuns, qopts...) + if rf, ok := ret.Get(0).(func(context.Context, sqlutil.DataSource, []*pipeline.Run, bool) error); ok { + r0 = rf(ctx, ds, runs, saveSuccessfulTaskRuns) } else { r0 = ret.Error(0) } @@ -255,17 +241,17 @@ func (_m *Runner) Ready() error { return r0 } -// ResumeRun provides a mock function with given fields: taskID, value, err -func (_m *Runner) ResumeRun(taskID uuid.UUID, value interface{}, err error) error { - ret := _m.Called(taskID, value, err) +// ResumeRun provides a mock function with given fields: ctx, taskID, value, err +func (_m *Runner) ResumeRun(ctx context.Context, taskID uuid.UUID, value interface{}, err error) error { + ret := _m.Called(ctx, taskID, value, err) if len(ret) == 0 { panic("no return value specified for ResumeRun") } var r0 error - if rf, ok := ret.Get(0).(func(uuid.UUID, interface{}, error) error); ok { - r0 = rf(taskID, value, err) + if rf, ok := ret.Get(0).(func(context.Context, uuid.UUID, interface{}, error) error); ok { + r0 = rf(ctx, taskID, value, err) } else { r0 = ret.Error(0) } @@ -274,7 +260,7 @@ func (_m *Runner) ResumeRun(taskID uuid.UUID, value interface{}, err error) erro } // Run provides a mock function with given fields: ctx, run, l, saveSuccessfulTaskRuns, fn -func (_m *Runner) Run(ctx context.Context, run *pipeline.Run, l logger.Logger, saveSuccessfulTaskRuns bool, fn func(pg.Queryer) error) (bool, error) { +func (_m *Runner) Run(ctx context.Context, run *pipeline.Run, l logger.Logger, saveSuccessfulTaskRuns bool, fn func(sqlutil.DataSource) error) (bool, error) { ret := _m.Called(ctx, run, l, saveSuccessfulTaskRuns, fn) if len(ret) == 0 { @@ -283,16 +269,16 @@ func (_m *Runner) Run(ctx context.Context, run *pipeline.Run, l logger.Logger, s var r0 bool var r1 error - if rf, ok := ret.Get(0).(func(context.Context, *pipeline.Run, logger.Logger, bool, func(pg.Queryer) error) (bool, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, *pipeline.Run, logger.Logger, bool, func(sqlutil.DataSource) error) (bool, error)); ok { return rf(ctx, run, l, saveSuccessfulTaskRuns, fn) } - if rf, ok := ret.Get(0).(func(context.Context, *pipeline.Run, logger.Logger, bool, func(pg.Queryer) error) bool); ok { + if rf, ok := ret.Get(0).(func(context.Context, *pipeline.Run, logger.Logger, bool, func(sqlutil.DataSource) error) bool); ok { r0 = rf(ctx, run, l, saveSuccessfulTaskRuns, fn) } else { r0 = ret.Get(0).(bool) } - if rf, ok := ret.Get(1).(func(context.Context, *pipeline.Run, logger.Logger, bool, func(pg.Queryer) error) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, *pipeline.Run, logger.Logger, bool, func(sqlutil.DataSource) error) error); ok { r1 = rf(ctx, run, l, saveSuccessfulTaskRuns, fn) } else { r1 = ret.Error(1) diff --git a/core/services/pipeline/orm.go b/core/services/pipeline/orm.go index c32693e4db4..3bebfb8cbad 100644 --- a/core/services/pipeline/orm.go +++ b/core/services/pipeline/orm.go @@ -14,6 +14,7 @@ import ( "github.com/jmoiron/sqlx" "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/pg" @@ -71,33 +72,42 @@ const KeepersObservationSource = ` encode_check_upkeep_tx -> check_upkeep_tx -> decode_check_upkeep_tx -> calculate_perform_data_len -> perform_data_lessthan_limit -> check_perform_data_limit -> encode_perform_upkeep_tx -> simulate_perform_upkeep_tx -> decode_check_perform_tx -> check_success -> perform_upkeep_tx ` +type CreateDataSource interface { + GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error +} + //go:generate mockery --quiet --name ORM --output ./mocks/ --case=underscore type ORM interface { services.Service - CreateSpec(pipeline Pipeline, maxTaskTimeout models.Interval, qopts ...pg.QOpt) (int32, error) - CreateRun(run *Run, qopts ...pg.QOpt) (err error) - InsertRun(run *Run, qopts ...pg.QOpt) error - DeleteRun(id int64) error - StoreRun(run *Run, qopts ...pg.QOpt) (restart bool, err error) - UpdateTaskRunResult(taskID uuid.UUID, result Result) (run Run, start bool, err error) - InsertFinishedRun(run *Run, saveSuccessfulTaskRuns bool, qopts ...pg.QOpt) (err error) - InsertFinishedRunWithSpec(run *Run, saveSuccessfulTaskRuns bool, qopts ...pg.QOpt) (err error) + + // ds is optional and to be removed after completing https://smartcontract-it.atlassian.net/browse/BCF-2978 + CreateSpec(ctx context.Context, ds CreateDataSource, pipeline Pipeline, maxTaskTimeout models.Interval) (int32, error) + CreateRun(ctx context.Context, run *Run) (err error) + InsertRun(ctx context.Context, run *Run) error + DeleteRun(ctx context.Context, id int64) error + StoreRun(ctx context.Context, run *Run) (restart bool, err error) + UpdateTaskRunResult(ctx context.Context, taskID uuid.UUID, result Result) (run Run, start bool, err error) + InsertFinishedRun(ctx context.Context, run *Run, saveSuccessfulTaskRuns bool) (err error) + InsertFinishedRunWithSpec(ctx context.Context, run *Run, saveSuccessfulTaskRuns bool) (err error) // InsertFinishedRuns inserts all the given runs into the database. // If saveSuccessfulTaskRuns is false, only errored runs are saved. - InsertFinishedRuns(run []*Run, saveSuccessfulTaskRuns bool, qopts ...pg.QOpt) (err error) + InsertFinishedRuns(ctx context.Context, run []*Run, saveSuccessfulTaskRuns bool) (err error) DeleteRunsOlderThan(context.Context, time.Duration) error - FindRun(id int64) (Run, error) - GetAllRuns() ([]Run, error) + FindRun(ctx context.Context, id int64) (Run, error) + GetAllRuns(ctx context.Context) ([]Run, error) GetUnfinishedRuns(context.Context, time.Time, func(run Run) error) error - GetQ() pg.Q + + DataSource() sqlutil.DataSource + WithDataSource(sqlutil.DataSource) ORM + Transact(context.Context, func(ORM) error) error } type orm struct { services.StateMachine - q pg.Q + ds sqlutil.DataSource lggr logger.Logger maxSuccessfulRuns uint64 // jobID => count @@ -109,17 +119,14 @@ type orm struct { var _ ORM = (*orm)(nil) -func NewORM(db *sqlx.DB, lggr logger.Logger, cfg pg.QConfig, jobPipelineMaxSuccessfulRuns uint64) *orm { +func NewORM(ds sqlutil.DataSource, lggr logger.Logger, jobPipelineMaxSuccessfulRuns uint64) *orm { ctx, cancel := context.WithCancel(context.Background()) return &orm{ - services.StateMachine{}, - pg.NewQ(db, lggr, cfg), - lggr.Named("PipelineORM"), - jobPipelineMaxSuccessfulRuns, - sync.Map{}, - sync.WaitGroup{}, - ctx, - cancel, + ds: ds, + lggr: lggr.Named("PipelineORM"), + maxSuccessfulRuns: jobPipelineMaxSuccessfulRuns, + ctx: ctx, + cncl: cancel, } } @@ -152,23 +159,56 @@ func (o *orm) HealthReport() map[string]error { return map[string]error{o.Name(): o.Healthy()} } -func (o *orm) CreateSpec(pipeline Pipeline, maxTaskDuration models.Interval, qopts ...pg.QOpt) (id int32, err error) { - q := o.q.WithOpts(qopts...) +func (o *orm) Transact(ctx context.Context, fn func(ORM) error) error { + return sqlutil.Transact(ctx, func(tx sqlutil.DataSource) ORM { + return o.withDataSource(tx) + }, o.ds, nil, func(tx ORM) error { + defer func() { + if err := tx.Close(); err != nil { + o.lggr.Warnw("Error closing temporary transactional ORM", "err", err) + } + }() + return fn(tx) + }) +} + +func (o *orm) DataSource() sqlutil.DataSource { return o.ds } + +func (o *orm) WithDataSource(ds sqlutil.DataSource) ORM { return o.withDataSource(ds) } + +func (o *orm) withDataSource(ds sqlutil.DataSource) *orm { + ctx, cancel := context.WithCancel(context.Background()) + return &orm{ + ds: ds, + lggr: o.lggr, + maxSuccessfulRuns: o.maxSuccessfulRuns, + ctx: ctx, + cncl: cancel, + } +} + +func (o *orm) transact(ctx context.Context, fn func(*orm) error) error { + return sqlutil.Transact(ctx, o.withDataSource, o.ds, nil, fn) +} + +func (o *orm) CreateSpec(ctx context.Context, ds CreateDataSource, pipeline Pipeline, maxTaskDuration models.Interval) (id int32, err error) { sql := `INSERT INTO pipeline_specs (dot_dag_source, max_task_duration, created_at) VALUES ($1, $2, NOW()) RETURNING id;` - err = q.Get(&id, sql, pipeline.Source, maxTaskDuration) + if ds == nil { + ds = o.ds + } + err = ds.GetContext(ctx, &id, sql, pipeline.Source, maxTaskDuration) return id, errors.WithStack(err) } -func (o *orm) CreateRun(run *Run, qopts ...pg.QOpt) (err error) { +func (o *orm) CreateRun(ctx context.Context, run *Run) (err error) { if run.CreatedAt.IsZero() { return errors.New("run.CreatedAt must be set") } - q := o.q.WithOpts(qopts...) - err = q.Transaction(func(tx pg.Queryer) error { - if e := o.InsertRun(run, pg.WithQueryer(tx)); e != nil { + err = o.transact(ctx, func(tx *orm) error { + if e := tx.InsertRun(ctx, run); e != nil { return errors.Wrap(e, "error inserting pipeline_run") } @@ -182,10 +222,9 @@ func (o *orm) CreateRun(run *Run, qopts ...pg.QOpt) (err error) { run.PipelineTaskRuns[i].PipelineRunID = run.ID } - sql := ` - INSERT INTO pipeline_task_runs (pipeline_run_id, id, type, index, output, error, dot_id, created_at) + sql := `INSERT INTO pipeline_task_runs (pipeline_run_id, id, type, index, output, error, dot_id, created_at) VALUES (:pipeline_run_id, :id, :type, :index, :output, :error, :dot_id, :created_at);` - _, err = tx.NamedExec(sql, run.PipelineTaskRuns) + _, err = tx.ds.NamedExecContext(ctx, sql, run.PipelineTaskRuns) return err }) @@ -193,33 +232,34 @@ func (o *orm) CreateRun(run *Run, qopts ...pg.QOpt) (err error) { } // InsertRun inserts a run into the database -func (o *orm) InsertRun(run *Run, qopts ...pg.QOpt) error { +func (o *orm) InsertRun(ctx context.Context, run *Run) error { if run.Status() == RunStatusCompleted { - defer o.Prune(o.q, run.PruningKey) + defer o.prune(o.ds, run.PruningKey) } - q := o.q.WithOpts(qopts...) - sql := `INSERT INTO pipeline_runs (pipeline_spec_id, pruning_key, meta, all_errors, fatal_errors, inputs, outputs, created_at, finished_at, state) + query, args, err := o.ds.BindNamed(`INSERT INTO pipeline_runs (pipeline_spec_id, pruning_key, meta, all_errors, fatal_errors, inputs, outputs, created_at, finished_at, state) VALUES (:pipeline_spec_id, :pruning_key, :meta, :all_errors, :fatal_errors, :inputs, :outputs, :created_at, :finished_at, :state) - RETURNING *;` - return q.GetNamed(sql, run, run) + RETURNING *;`, run) + if err != nil { + return fmt.Errorf("error binding arg: %w", err) + } + return o.ds.GetContext(ctx, run, query, args...) } // StoreRun will persist a partially executed run before suspending, or finish a run. // If `restart` is true, then new task run data is available and the run should be resumed immediately. -func (o *orm) StoreRun(run *Run, qopts ...pg.QOpt) (restart bool, err error) { - q := o.q.WithOpts(qopts...) - err = q.Transaction(func(tx pg.Queryer) error { +func (o *orm) StoreRun(ctx context.Context, run *Run) (restart bool, err error) { + err = o.transact(ctx, func(tx *orm) error { finished := run.FinishedAt.Valid if !finished { // Lock the current run. This prevents races with /v2/resume sql := `SELECT id FROM pipeline_runs WHERE id = $1 FOR UPDATE;` - if _, err = tx.Exec(sql, run.ID); err != nil { + if _, err = tx.ds.ExecContext(ctx, sql, run.ID); err != nil { return errors.Wrap(err, "StoreRun") } taskRuns := []TaskRun{} // Reload task runs, we want to check for any changes while the run was ongoing - if err = sqlx.Select(tx, &taskRuns, `SELECT * FROM pipeline_task_runs WHERE pipeline_run_id = $1`, run.ID); err != nil { + if err = tx.ds.SelectContext(ctx, &taskRuns, `SELECT * FROM pipeline_task_runs WHERE pipeline_run_id = $1`, run.ID); err != nil { return errors.Wrap(err, "StoreRun") } @@ -246,17 +286,17 @@ func (o *orm) StoreRun(run *Run, qopts ...pg.QOpt) (restart bool, err error) { // Suspend the run run.State = RunStatusSuspended - if _, err = sqlx.NamedExec(tx, `UPDATE pipeline_runs SET state = :state WHERE id = :id`, run); err != nil { + if _, err = tx.ds.NamedExecContext(ctx, `UPDATE pipeline_runs SET state = :state WHERE id = :id`, run); err != nil { return errors.Wrap(err, "StoreRun") } } else { - defer o.Prune(tx, run.PruningKey) + defer o.prune(tx.ds, run.PruningKey) // Simply finish the run, no need to do any sort of locking if run.Outputs.Val == nil || len(run.FatalErrors)+len(run.AllErrors) == 0 { return errors.Errorf("run must have both Outputs and Errors, got Outputs: %#v, FatalErrors: %#v, AllErrors: %#v", run.Outputs.Val, run.FatalErrors, run.AllErrors) } sql := `UPDATE pipeline_runs SET state = :state, finished_at = :finished_at, all_errors= :all_errors, fatal_errors= :fatal_errors, outputs = :outputs WHERE id = :id` - if _, err = sqlx.NamedExec(tx, sql, run); err != nil { + if _, err = tx.ds.NamedExecContext(ctx, sql, run); err != nil { return errors.Wrap(err, "StoreRun") } } @@ -272,7 +312,7 @@ func (o *orm) StoreRun(run *Run, qopts ...pg.QOpt) (restart bool, err error) { // NOTE: can't use Select() to auto scan because we're using NamedQuery, // sqlx.Named + Select is possible but it's about the same amount of code var rows *sqlx.Rows - rows, err = sqlx.NamedQuery(tx, sql, run.PipelineTaskRuns) + rows, err = sqlx.NamedQueryContext(ctx, tx.ds, sql, run.PipelineTaskRuns) if err != nil { return errors.Wrap(err, "StoreRun") } @@ -288,17 +328,17 @@ func (o *orm) StoreRun(run *Run, qopts ...pg.QOpt) (restart bool, err error) { } // DeleteRun cleans up a run that failed and is marked failEarly (should leave no trace of the run) -func (o *orm) DeleteRun(id int64) error { +func (o *orm) DeleteRun(ctx context.Context, id int64) error { // NOTE: this will cascade and wipe pipeline_task_runs too - _, err := o.q.Exec(`DELETE FROM pipeline_runs WHERE id = $1`, id) + _, err := o.ds.ExecContext(ctx, `DELETE FROM pipeline_runs WHERE id = $1`, id) return err } -func (o *orm) UpdateTaskRunResult(taskID uuid.UUID, result Result) (run Run, start bool, err error) { +func (o *orm) UpdateTaskRunResult(ctx context.Context, taskID uuid.UUID, result Result) (run Run, start bool, err error) { if result.OutputDB().Valid && result.ErrorDB().Valid { panic("run result must specify either output or error, not both") } - err = o.q.Transaction(func(tx pg.Queryer) error { + err = o.transact(ctx, func(tx *orm) error { sql := ` SELECT pipeline_runs.*, pipeline_specs.dot_dag_source "pipeline_spec.dot_dag_source", job_pipeline_specs.job_id "job_id" FROM pipeline_runs @@ -307,13 +347,13 @@ func (o *orm) UpdateTaskRunResult(taskID uuid.UUID, result Result) (run Run, sta JOIN job_pipeline_specs ON (job_pipeline_specs.pipeline_spec_id = pipeline_specs.id) WHERE pipeline_task_runs.id = $1 AND pipeline_runs.state in ('running', 'suspended') FOR UPDATE` - if err = tx.Get(&run, sql, taskID); err != nil { + if err = tx.ds.GetContext(ctx, &run, sql, taskID); err != nil { return fmt.Errorf("failed to find pipeline run for task ID %s: %w", taskID.String(), err) } // Update the task with result sql = `UPDATE pipeline_task_runs SET output = $2, error = $3, finished_at = $4 WHERE id = $1` - if _, err = tx.Exec(sql, taskID, result.OutputDB(), result.ErrorDB(), time.Now()); err != nil { + if _, err = tx.ds.ExecContext(ctx, sql, taskID, result.OutputDB(), result.ErrorDB(), time.Now()); err != nil { return fmt.Errorf("failed to update pipeline task run: %w", err) } @@ -322,21 +362,20 @@ func (o *orm) UpdateTaskRunResult(taskID uuid.UUID, result Result) (run Run, sta run.State = RunStatusRunning sql = `UPDATE pipeline_runs SET state = $2 WHERE id = $1` - if _, err = tx.Exec(sql, run.ID, run.State); err != nil { + if _, err = tx.ds.ExecContext(ctx, sql, run.ID, run.State); err != nil { return fmt.Errorf("failed to update pipeline run state: %w", err) } } - return loadAssociations(tx, []*Run{&run}) + return loadAssociations(ctx, tx.ds, []*Run{&run}) }) return run, start, err } // InsertFinishedRuns inserts all the given runs into the database. -func (o *orm) InsertFinishedRuns(runs []*Run, saveSuccessfulTaskRuns bool, qopts ...pg.QOpt) error { - q := o.q.WithOpts(qopts...) - err := q.Transaction(func(tx pg.Queryer) error { +func (o *orm) InsertFinishedRuns(ctx context.Context, runs []*Run, saveSuccessfulTaskRuns bool) error { + err := o.transact(ctx, func(tx *orm) error { pipelineRunsQuery := ` INSERT INTO pipeline_runs (pipeline_spec_id, pruning_key, meta, all_errors, fatal_errors, inputs, outputs, created_at, finished_at, state) @@ -344,7 +383,7 @@ VALUES (:pipeline_spec_id, :pruning_key, :meta, :all_errors, :fatal_errors, :inputs, :outputs, :created_at, :finished_at, :state) RETURNING id ` - rows, errQ := tx.NamedQuery(pipelineRunsQuery, runs) + rows, errQ := sqlx.NamedQueryContext(ctx, tx.ds, pipelineRunsQuery, runs) if errQ != nil { return errors.Wrap(errQ, "inserting finished pipeline runs") } @@ -369,7 +408,7 @@ RETURNING id defer func() { for pruningKey := range pruningKeysm { - o.Prune(tx, pruningKey) + o.prune(tx.ds, pruningKey) } }() @@ -385,7 +424,7 @@ VALUES (:pipeline_run_id, :id, :type, :index, :output, :error, :dot_id, :created pipelineTaskRuns = append(pipelineTaskRuns, run.PipelineTaskRuns...) } - _, errE := tx.NamedExec(pipelineTaskRunsQuery, pipelineTaskRuns) + _, errE := tx.ds.NamedExecContext(ctx, pipelineTaskRunsQuery, pipelineTaskRuns) return errors.Wrap(errE, "insert pipeline task runs") }) return errors.Wrap(err, "InsertFinishedRuns failed") @@ -411,7 +450,7 @@ func (o *orm) checkFinishedRun(run *Run, saveSuccessfulTaskRuns bool) error { // If saveSuccessfulTaskRuns = false, we only save errored runs. // That way if the job is run frequently (such as OCR) we avoid saving a large number of successful task runs // which do not provide much value. -func (o *orm) InsertFinishedRun(run *Run, saveSuccessfulTaskRuns bool, qopts ...pg.QOpt) (err error) { +func (o *orm) InsertFinishedRun(ctx context.Context, run *Run, saveSuccessfulTaskRuns bool) (err error) { if err = o.checkFinishedRun(run, saveSuccessfulTaskRuns); err != nil { return err } @@ -421,13 +460,12 @@ func (o *orm) InsertFinishedRun(run *Run, saveSuccessfulTaskRuns bool, qopts ... return nil } - q := o.q.WithOpts(qopts...) - err = q.Transaction(o.insertFinishedRunTx(run, saveSuccessfulTaskRuns)) + err = o.insertFinishedRun(ctx, run, saveSuccessfulTaskRuns) return errors.Wrap(err, "InsertFinishedRun failed") } // InsertFinishedRunWithSpec works like InsertFinishedRun but also inserts the pipeline spec. -func (o *orm) InsertFinishedRunWithSpec(run *Run, saveSuccessfulTaskRuns bool, qopts ...pg.QOpt) (err error) { +func (o *orm) InsertFinishedRunWithSpec(ctx context.Context, run *Run, saveSuccessfulTaskRuns bool) (err error) { if err = o.checkFinishedRun(run, saveSuccessfulTaskRuns); err != nil { return err } @@ -437,57 +475,55 @@ func (o *orm) InsertFinishedRunWithSpec(run *Run, saveSuccessfulTaskRuns bool, q return nil } - q := o.q.WithOpts(qopts...) - err = q.Transaction(func(tx pg.Queryer) error { + err = o.transact(ctx, func(tx *orm) error { sqlStmt1 := `INSERT INTO pipeline_specs (dot_dag_source, max_task_duration, created_at) VALUES ($1, $2, NOW()) RETURNING id;` - err = tx.Get(&run.PipelineSpecID, sqlStmt1, run.PipelineSpec.DotDagSource, run.PipelineSpec.MaxTaskDuration) + err = tx.ds.GetContext(ctx, &run.PipelineSpecID, sqlStmt1, run.PipelineSpec.DotDagSource, run.PipelineSpec.MaxTaskDuration) if err != nil { return errors.Wrap(err, "failed to insert pipeline_specs") } // This `job_pipeline_specs` record won't be primary since when this method is called, the job already exists, so it will have primary record. sqlStmt2 := `INSERT INTO job_pipeline_specs (job_id, pipeline_spec_id, is_primary) VALUES ($1, $2, false)` - _, err = tx.Exec(sqlStmt2, run.JobID, run.PipelineSpecID) + _, err = tx.ds.ExecContext(ctx, sqlStmt2, run.JobID, run.PipelineSpecID) if err != nil { return errors.Wrap(err, "failed to insert job_pipeline_specs") } - return o.insertFinishedRunTx(run, saveSuccessfulTaskRuns)(tx) + return tx.insertFinishedRun(ctx, run, saveSuccessfulTaskRuns) }) return errors.Wrap(err, "InsertFinishedRun failed") } -func (o *orm) insertFinishedRunTx(run *Run, saveSuccessfulTaskRuns bool) func(tx pg.Queryer) error { - return func(tx pg.Queryer) error { - sql := `INSERT INTO pipeline_runs (pipeline_spec_id, pruning_key, meta, all_errors, fatal_errors, inputs, outputs, created_at, finished_at, state) +func (o *orm) insertFinishedRun(ctx context.Context, run *Run, saveSuccessfulTaskRuns bool) error { + sql := `INSERT INTO pipeline_runs (pipeline_spec_id, pruning_key, meta, all_errors, fatal_errors, inputs, outputs, created_at, finished_at, state) VALUES (:pipeline_spec_id, :pruning_key, :meta, :all_errors, :fatal_errors, :inputs, :outputs, :created_at, :finished_at, :state) RETURNING id;` - query, args, e := tx.BindNamed(sql, run) - if e != nil { - return errors.Wrap(e, "failed to bind") - } + query, args, err := o.ds.BindNamed(sql, run) + if err != nil { + return errors.Wrap(err, "failed to bind") + } - if err := tx.QueryRowx(query, args...).Scan(&run.ID); err != nil { - return errors.Wrap(err, "error inserting finished pipeline_run") - } + if err = o.ds.QueryRowxContext(ctx, query, args...).Scan(&run.ID); err != nil { + return errors.Wrap(err, "error inserting finished pipeline_run") + } - // update the ID key everywhere - for i := range run.PipelineTaskRuns { - run.PipelineTaskRuns[i].PipelineRunID = run.ID - } + // update the ID key everywhere + for i := range run.PipelineTaskRuns { + run.PipelineTaskRuns[i].PipelineRunID = run.ID + } - if !saveSuccessfulTaskRuns && !run.HasErrors() { - return nil - } + if !saveSuccessfulTaskRuns && !run.HasErrors() { + return nil + } - defer o.Prune(tx, run.PruningKey) - sql = ` + defer o.prune(o.ds, run.PruningKey) + sql = ` INSERT INTO pipeline_task_runs (pipeline_run_id, id, type, index, output, error, dot_id, created_at, finished_at) VALUES (:pipeline_run_id, :id, :type, :index, :output, :error, :dot_id, :created_at, :finished_at);` - _, err := tx.NamedExec(sql, run.PipelineTaskRuns) - return errors.Wrap(err, "failed to insert pipeline_task_runs") - } + _, err = o.ds.NamedExecContext(ctx, sql, run.PipelineTaskRuns) + return errors.Wrap(err, "failed to insert pipeline_task_runs") + } // DeleteRunsOlderThan deletes all pipeline_runs that have been finished for a certain threshold to free DB space @@ -495,14 +531,12 @@ func (o *orm) insertFinishedRunTx(run *Run, saveSuccessfulTaskRuns bool) func(tx func (o *orm) DeleteRunsOlderThan(ctx context.Context, threshold time.Duration) error { start := time.Now() - q := o.q.WithOpts(pg.WithParentCtxInheritTimeout(ctx)) - queryThreshold := start.Add(-threshold) rowsDeleted := int64(0) err := pg.Batch(func(_, limit uint) (count uint, err error) { - result, cancel, err := q.ExecQIter(` + result, err := o.ds.ExecContext(ctx, ` WITH batched_pipeline_runs AS ( SELECT * FROM pipeline_runs WHERE finished_at < ($1) @@ -515,7 +549,6 @@ WHERE pipeline_runs.id = batched_pipeline_runs.id`, queryThreshold, limit, ) - defer cancel() if err != nil { return count, errors.Wrap(err, "DeleteRunsOlderThan failed to delete old pipeline_runs") } @@ -539,7 +572,7 @@ WHERE pipeline_runs.id = batched_pipeline_runs.id`, o.lggr.Debugw("pipeline_runs reaper VACUUM ANALYZE query completed", "duration", time.Since(start)) }(deleteTS) - err = q.ExecQ("VACUUM ANALYZE pipeline_runs") + _, err = o.ds.ExecContext(ctx, "VACUUM ANALYZE pipeline_runs") if err != nil { o.lggr.Warnw("DeleteRunsOlderThan successfully deleted old pipeline_runs rows, but failed to run VACUUM ANALYZE", "err", err) return nil @@ -548,13 +581,13 @@ WHERE pipeline_runs.id = batched_pipeline_runs.id`, return nil } -func (o *orm) FindRun(id int64) (r Run, err error) { +func (o *orm) FindRun(ctx context.Context, id int64) (r Run, err error) { var runs []*Run - err = o.q.Transaction(func(tx pg.Queryer) error { - if err = tx.Select(&runs, `SELECT * from pipeline_runs WHERE id = $1 LIMIT 1`, id); err != nil { + err = o.transact(ctx, func(tx *orm) error { + if err = tx.ds.SelectContext(ctx, &runs, `SELECT * from pipeline_runs WHERE id = $1 LIMIT 1`, id); err != nil { return errors.Wrap(err, "failed to load runs") } - return loadAssociations(tx, runs) + return loadAssociations(ctx, tx.ds, runs) }) if len(runs) == 0 { return r, sql.ErrNoRows @@ -562,15 +595,15 @@ func (o *orm) FindRun(id int64) (r Run, err error) { return *runs[0], err } -func (o *orm) GetAllRuns() (runs []Run, err error) { +func (o *orm) GetAllRuns(ctx context.Context) (runs []Run, err error) { var runsPtrs []*Run - err = o.q.Transaction(func(tx pg.Queryer) error { - err = tx.Select(&runsPtrs, `SELECT * from pipeline_runs ORDER BY created_at ASC, id ASC`) + err = o.transact(ctx, func(tx *orm) error { + err = tx.ds.SelectContext(ctx, &runsPtrs, `SELECT * from pipeline_runs ORDER BY created_at ASC, id ASC`) if err != nil { return errors.Wrap(err, "failed to load runs") } - return loadAssociations(tx, runsPtrs) + return loadAssociations(ctx, tx.ds, runsPtrs) }) runs = make([]Run, len(runsPtrs)) for i, runPtr := range runsPtrs { @@ -580,17 +613,16 @@ func (o *orm) GetAllRuns() (runs []Run, err error) { } func (o *orm) GetUnfinishedRuns(ctx context.Context, now time.Time, fn func(run Run) error) error { - q := o.q.WithOpts(pg.WithParentCtx(ctx)) return pg.Batch(func(offset, limit uint) (count uint, err error) { var runs []*Run - err = q.Transaction(func(tx pg.Queryer) error { - err = tx.Select(&runs, `SELECT * from pipeline_runs WHERE state = $1 AND created_at < $2 ORDER BY created_at ASC, id ASC OFFSET $3 LIMIT $4`, RunStatusRunning, now, offset, limit) + err = o.transact(ctx, func(tx *orm) error { + err = tx.ds.SelectContext(ctx, &runs, `SELECT * from pipeline_runs WHERE state = $1 AND created_at < $2 ORDER BY created_at ASC, id ASC OFFSET $3 LIMIT $4`, RunStatusRunning, now, offset, limit) if err != nil { return errors.Wrap(err, "failed to load runs") } - err = loadAssociations(tx, runs) + err = loadAssociations(ctx, tx.ds, runs) if err != nil { return err } @@ -608,7 +640,7 @@ func (o *orm) GetUnfinishedRuns(ctx context.Context, now time.Time, fn func(run } // loads PipelineSpec and PipelineTaskRuns for Runs in exactly 2 queries -func loadAssociations(q pg.Queryer, runs []*Run) error { +func loadAssociations(ctx context.Context, ds sqlutil.DataSource, runs []*Run) error { if len(runs) == 0 { return nil } @@ -635,7 +667,7 @@ func loadAssociations(q pg.Queryer, runs []*Run) error { LEFT JOIN job_pipeline_specs jps ON jps.pipeline_spec_id=ps.id LEFT JOIN jobs ON jobs.id=jps.job_id WHERE ps.id = ANY($1)` - if err := q.Select(&specs, sqlQuery, pipelineSpecIDs); err != nil { + if err := ds.SelectContext(ctx, &specs, sqlQuery, pipelineSpecIDs); err != nil { return errors.Wrap(err, "failed to postload pipeline_specs for runs") } for _, spec := range specs { @@ -647,7 +679,7 @@ func loadAssociations(q pg.Queryer, runs []*Run) error { var taskRuns []TaskRun taskRunPRIDM := make(map[int64][]TaskRun, len(runs)) // keyed by pipelineRunID - if err := q.Select(&taskRuns, `SELECT * FROM pipeline_task_runs WHERE pipeline_run_id = ANY($1) ORDER BY created_at ASC, id ASC`, pipelineRunIDs); err != nil { + if err := ds.SelectContext(ctx, &taskRuns, `SELECT * FROM pipeline_task_runs WHERE pipeline_run_id = ANY($1) ORDER BY created_at ASC, id ASC`, pipelineRunIDs); err != nil { return errors.Wrap(err, "failed to postload pipeline_task_runs for runs") } for _, taskRun := range taskRuns { @@ -662,10 +694,6 @@ func loadAssociations(q pg.Queryer, runs []*Run) error { return nil } -func (o *orm) GetQ() pg.Q { - return o.q -} - func (o *orm) loadCount(jobID int32) *atomic.Uint64 { // fast path; avoids allocation actual, exists := o.pm.Load(jobID) @@ -681,7 +709,7 @@ func (o *orm) loadCount(jobID int32) *atomic.Uint64 { // this value or higher const syncLimit = 1000 -// Prune attempts to keep the pipeline_runs table capped close to the +// prune attempts to keep the pipeline_runs table capped close to the // maxSuccessfulRuns length for each job_id. // // It does this synchronously for small values and async/sampled for large @@ -689,13 +717,13 @@ const syncLimit = 1000 // // Note this does not guarantee the pipeline_runs table is kept to exactly the // max length, rather that it doesn't excessively larger than it. -func (o *orm) Prune(tx pg.Queryer, jobID int32) { +func (o *orm) prune(ds sqlutil.DataSource, jobID int32) { if jobID == 0 { o.lggr.Panic("expected a non-zero job ID") } // For small maxSuccessfulRuns its fast enough to prune every time if o.maxSuccessfulRuns < syncLimit { - o.execPrune(tx, jobID) + o.execPrune(o.ctx, ds, jobID) return } // for large maxSuccessfulRuns we do it async on a sampled basis @@ -708,9 +736,11 @@ func (o *orm) Prune(tx pg.Queryer, jobID int32) { go func() { o.lggr.Debugw("Pruning runs", "jobID", jobID, "count", val, "every", every, "maxSuccessfulRuns", o.maxSuccessfulRuns) defer o.wg.Done() - // Must not use tx here since it's async and the transaction + // Must not use ds here since it's async and the transaction // could be stale - o.execPrune(o.q.WithOpts(pg.WithLongQueryTimeout()), jobID) + ctx, cancel := context.WithTimeout(sqlutil.WithoutDefaultTimeout(o.ctx), time.Minute) + defer cancel() + o.execPrune(ctx, o.ds, jobID) }() }) if !ok { @@ -720,8 +750,8 @@ func (o *orm) Prune(tx pg.Queryer, jobID int32) { } } -func (o *orm) execPrune(q pg.Queryer, jobID int32) { - res, err := q.ExecContext(o.ctx, `DELETE FROM pipeline_runs WHERE pruning_key = $1 AND state = $2 AND id NOT IN ( +func (o *orm) execPrune(ctx context.Context, ds sqlutil.DataSource, jobID int32) { + res, err := ds.ExecContext(o.ctx, `DELETE FROM pipeline_runs WHERE pruning_key = $1 AND state = $2 AND id NOT IN ( SELECT id FROM pipeline_runs WHERE pruning_key = $1 AND state = $2 ORDER BY id DESC @@ -739,7 +769,7 @@ LIMIT $3 if rowsAffected == 0 { // check the spec still exists and garbage collect if necessary var exists bool - if err := q.GetContext(o.ctx, &exists, `SELECT EXISTS(SELECT ps.* FROM pipeline_specs ps JOIN job_pipeline_specs jps ON (ps.id=jps.pipeline_spec_id) WHERE jps.job_id = $1)`, jobID); err != nil { + if err := ds.GetContext(ctx, &exists, `SELECT EXISTS(SELECT ps.* FROM pipeline_specs ps JOIN job_pipeline_specs jps ON (ps.id=jps.pipeline_spec_id) WHERE jps.job_id = $1)`, jobID); err != nil { o.lggr.Errorw("Failed check existence of pipeline_spec while pruning runs", "err", err, "jobID", jobID) return } diff --git a/core/services/pipeline/orm_test.go b/core/services/pipeline/orm_test.go index e5bf319f056..88155bc04ba 100644 --- a/core/services/pipeline/orm_test.go +++ b/core/services/pipeline/orm_test.go @@ -71,7 +71,7 @@ func setupORM(t *testing.T, heavy bool) (db *sqlx.DB, orm pipeline.ORM, jorm job db = pgtest.NewSqlxDB(t) } cfg := ormconfig{pgtest.NewQConfig(true)} - orm = pipeline.NewORM(db, logger.TestLogger(t), cfg, cfg.JobPipelineMaxSuccessfulRuns()) + orm = pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipelineMaxSuccessfulRuns()) config := configtest.NewTestGeneralConfig(t) lggr := logger.TestLogger(t) keyStore := cltest.NewKeyStore(t, db, config.Database()) @@ -91,6 +91,7 @@ func setupLiteORM(t *testing.T) (db *sqlx.DB, orm pipeline.ORM, jorm job.ORM) { } func Test_PipelineORM_CreateSpec(t *testing.T) { + ctx := testutils.Context(t) db, orm, _ := setupLiteORM(t) var ( @@ -102,7 +103,7 @@ func Test_PipelineORM_CreateSpec(t *testing.T) { Source: source, } - id, err := orm.CreateSpec(p, maxTaskDuration) + id, err := orm.CreateSpec(ctx, nil, p, maxTaskDuration) require.NoError(t, err) actual := pipeline.Spec{} @@ -121,7 +122,8 @@ func Test_PipelineORM_FindRun(t *testing.T) { require.NoError(t, err) expected := mustInsertPipelineRun(t, orm) - run, err := orm.FindRun(expected.ID) + ctx := testutils.Context(t) + run, err := orm.FindRun(ctx, expected.ID) require.NoError(t, err) require.Equal(t, expected.ID, run.ID) @@ -138,12 +140,14 @@ func mustInsertPipelineRun(t *testing.T, orm pipeline.ORM) pipeline.Run { FinishedAt: null.Time{}, } - require.NoError(t, orm.InsertRun(&run)) + ctx := testutils.Context(t) + require.NoError(t, orm.InsertRun(ctx, &run)) return run } func mustInsertAsyncRun(t *testing.T, orm pipeline.ORM, jobORM job.ORM) *pipeline.Run { t.Helper() + ctx := testutils.Context(t) s := ` ds1 [type=bridge async=true name="example-bridge" timeout=0 requestData=<{"data": {"coin": "BTC", "market": "USD"}}>] @@ -178,12 +182,13 @@ answer2 [type=bridge name=election_winner index=1]; CreatedAt: time.Now(), } - err = orm.CreateRun(run) + err = orm.CreateRun(ctx, run) require.NoError(t, err) return run } func TestInsertFinishedRuns(t *testing.T) { + ctx := testutils.Context(t) db, orm, _ := setupLiteORM(t) _, err := db.Exec(`SET CONSTRAINTS fk_pipeline_runs_pruning_key DEFERRED`) @@ -207,7 +212,7 @@ func TestInsertFinishedRuns(t *testing.T) { Outputs: jsonserializable.JSONSerializable{}, } - require.NoError(t, orm.InsertRun(&r)) + require.NoError(t, orm.InsertRun(ctx, &r)) r.PipelineTaskRuns = []pipeline.TaskRun{ { @@ -238,12 +243,13 @@ func TestInsertFinishedRuns(t *testing.T) { runs = append(runs, &r) } - err = orm.InsertFinishedRuns(runs, true) + err = orm.InsertFinishedRuns(ctx, runs, true) require.NoError(t, err) } func Test_PipelineORM_InsertFinishedRunWithSpec(t *testing.T) { + ctx := testutils.Context(t) db, orm, jorm := setupLiteORM(t) s := ` @@ -314,7 +320,7 @@ answer2 [type=bridge name=election_winner index=1]; run.AllErrors = append(run.AllErrors, null.NewString("", false)) run.State = pipeline.RunStatusCompleted - err = orm.InsertFinishedRunWithSpec(run, true) + err = orm.InsertFinishedRunWithSpec(ctx, run, true) require.NoError(t, err) var pipelineSpec pipeline.Spec @@ -330,6 +336,7 @@ answer2 [type=bridge name=election_winner index=1]; // Tests that inserting run results, then later updating the run results via upsert will work correctly. func Test_PipelineORM_StoreRun_ShouldUpsert(t *testing.T) { + ctx := testutils.Context(t) _, orm, jorm := setupLiteORM(t) run := mustInsertAsyncRun(t, orm, jorm) @@ -357,14 +364,14 @@ func Test_PipelineORM_StoreRun_ShouldUpsert(t *testing.T) { FinishedAt: null.TimeFrom(now), }, } - restart, err := orm.StoreRun(run) + restart, err := orm.StoreRun(ctx, run) require.NoError(t, err) // no new data, so we don't need a restart require.Equal(t, false, restart) // the run is paused require.Equal(t, pipeline.RunStatusSuspended, run.State) - r, err := orm.FindRun(run.ID) + r, err := orm.FindRun(ctx, run.ID) require.NoError(t, err) run = &r // this is an incomplete run, so partial results should be present (regardless of saveSuccessfulTaskRuns) @@ -388,14 +395,14 @@ func Test_PipelineORM_StoreRun_ShouldUpsert(t *testing.T) { FinishedAt: null.TimeFrom(now), }, } - restart, err = orm.StoreRun(run) + restart, err = orm.StoreRun(ctx, run) require.NoError(t, err) // no new data, so we don't need a restart require.Equal(t, false, restart) // the run is paused require.Equal(t, pipeline.RunStatusSuspended, run.State) - r, err = orm.FindRun(run.ID) + r, err = orm.FindRun(ctx, run.ID) require.NoError(t, err) run = &r // this is an incomplete run, so partial results should be present (regardless of saveSuccessfulTaskRuns) @@ -409,11 +416,12 @@ func Test_PipelineORM_StoreRun_ShouldUpsert(t *testing.T) { // Tests that trying to persist a partial run while new data became available (i.e. via /v2/restart) // will detect a restart and update the result data on the Run. func Test_PipelineORM_StoreRun_DetectsRestarts(t *testing.T) { + ctx := testutils.Context(t) db, orm, jorm := setupLiteORM(t) run := mustInsertAsyncRun(t, orm, jorm) - r, err := orm.FindRun(run.ID) + r, err := orm.FindRun(ctx, run.ID) require.NoError(t, err) require.Equal(t, run.Inputs, r.Inputs) @@ -459,7 +467,7 @@ func Test_PipelineORM_StoreRun_DetectsRestarts(t *testing.T) { }, } - restart, err := orm.StoreRun(run) + restart, err := orm.StoreRun(ctx, run) require.NoError(t, err) // new data available! immediately restart the run require.Equal(t, true, restart) @@ -474,6 +482,7 @@ func Test_PipelineORM_StoreRun_DetectsRestarts(t *testing.T) { } func Test_PipelineORM_StoreRun_UpdateTaskRunResult(t *testing.T) { + ctx := testutils.Context(t) _, orm, jorm := setupLiteORM(t) run := mustInsertAsyncRun(t, orm, jorm) @@ -525,13 +534,13 @@ func Test_PipelineORM_StoreRun_UpdateTaskRunResult(t *testing.T) { require.Equal(t, pipeline.RunStatusRunning, run.State) // Now store a partial run - restart, err := orm.StoreRun(run) + restart, err := orm.StoreRun(ctx, run) require.NoError(t, err) require.False(t, restart) // assert that run should be in "paused" state require.Equal(t, pipeline.RunStatusSuspended, run.State) - r, start, err := orm.UpdateTaskRunResult(ds1_id, pipeline.Result{Value: "foo"}) + r, start, err := orm.UpdateTaskRunResult(ctx, ds1_id, pipeline.Result{Value: "foo"}) run = &r require.NoError(t, err) assert.Greater(t, run.ID, int64(0)) @@ -555,6 +564,7 @@ func Test_PipelineORM_StoreRun_UpdateTaskRunResult(t *testing.T) { } func Test_PipelineORM_DeleteRun(t *testing.T) { + ctx := testutils.Context(t) _, orm, jorm := setupLiteORM(t) run := mustInsertAsyncRun(t, orm, jorm) @@ -582,21 +592,22 @@ func Test_PipelineORM_DeleteRun(t *testing.T) { FinishedAt: null.TimeFrom(now), }, } - restart, err := orm.StoreRun(run) + restart, err := orm.StoreRun(ctx, run) require.NoError(t, err) // no new data, so we don't need a restart require.Equal(t, false, restart) // the run is paused require.Equal(t, pipeline.RunStatusSuspended, run.State) - err = orm.DeleteRun(run.ID) + err = orm.DeleteRun(ctx, run.ID) require.NoError(t, err) - _, err = orm.FindRun(run.ID) + _, err = orm.FindRun(ctx, run.ID) require.Error(t, err, "not found") } func Test_PipelineORM_DeleteRunsOlderThan(t *testing.T) { + ctx := testutils.Context(t) _, orm, jorm := setupHeavyORM(t) var runsIds []int64 @@ -623,7 +634,7 @@ func Test_PipelineORM_DeleteRunsOlderThan(t *testing.T) { run.Outputs = jsonserializable.JSONSerializable{Val: 1, Valid: true} run.AllErrors = pipeline.RunErrors{null.StringFrom("SOMETHING")} - restart, err := orm.StoreRun(run) + restart, err := orm.StoreRun(ctx, run) assert.NoError(t, err) // no new data, so we don't need a restart assert.Equal(t, false, restart) @@ -635,13 +646,14 @@ func Test_PipelineORM_DeleteRunsOlderThan(t *testing.T) { assert.NoError(t, err) for _, runId := range runsIds { - _, err := orm.FindRun(runId) + _, err := orm.FindRun(ctx, runId) require.Error(t, err, "not found") } } func Test_GetUnfinishedRuns_Keepers(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) // The test configures single Keeper job with two running tasks. // GetUnfinishedRuns() expects to catch both running tasks. @@ -650,7 +662,7 @@ func Test_GetUnfinishedRuns_Keepers(t *testing.T) { lggr := logger.TestLogger(t) db := pgtest.NewSqlxDB(t) keyStore := cltest.NewKeyStore(t, db, config.Database()) - porm := pipeline.NewORM(db, lggr, config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + porm := pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()) bridgeORM := bridges.NewORM(db) jorm := job.NewORM(db, porm, bridgeORM, keyStore, lggr, config.Database()) @@ -684,7 +696,7 @@ func Test_GetUnfinishedRuns_Keepers(t *testing.T) { runID1 := uuid.New() runID2 := uuid.New() - err = porm.CreateRun(&pipeline.Run{ + err = porm.CreateRun(ctx, &pipeline.Run{ PipelineSpecID: keeperJob.PipelineSpecID, PruningKey: keeperJob.ID, State: pipeline.RunStatusRunning, @@ -701,7 +713,7 @@ func Test_GetUnfinishedRuns_Keepers(t *testing.T) { }) require.NoError(t, err) - err = porm.CreateRun(&pipeline.Run{ + err = porm.CreateRun(ctx, &pipeline.Run{ PipelineSpecID: keeperJob.PipelineSpecID, PruningKey: keeperJob.ID, State: pipeline.RunStatusRunning, @@ -744,6 +756,7 @@ func Test_GetUnfinishedRuns_Keepers(t *testing.T) { func Test_GetUnfinishedRuns_DirectRequest(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) // The test configures single DR job with two task runs: one is running and one is suspended. // GetUnfinishedRuns() expects to catch the one that is running. @@ -752,7 +765,7 @@ func Test_GetUnfinishedRuns_DirectRequest(t *testing.T) { lggr := logger.TestLogger(t) db := pgtest.NewSqlxDB(t) keyStore := cltest.NewKeyStore(t, db, config.Database()) - porm := pipeline.NewORM(db, lggr, config.Database(), config.JobPipeline().MaxSuccessfulRuns()) + porm := pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()) bridgeORM := bridges.NewORM(db) jorm := job.NewORM(db, porm, bridgeORM, keyStore, lggr, config.Database()) @@ -784,7 +797,7 @@ func Test_GetUnfinishedRuns_DirectRequest(t *testing.T) { runningID := uuid.New() - err = porm.CreateRun(&pipeline.Run{ + err = porm.CreateRun(ctx, &pipeline.Run{ PipelineSpecID: drJob.PipelineSpecID, PruningKey: drJob.ID, State: pipeline.RunStatusRunning, @@ -801,7 +814,7 @@ func Test_GetUnfinishedRuns_DirectRequest(t *testing.T) { }) require.NoError(t, err) - err = porm.CreateRun(&pipeline.Run{ + err = porm.CreateRun(ctx, &pipeline.Run{ PipelineSpecID: drJob.PipelineSpecID, PruningKey: drJob.ID, State: pipeline.RunStatusSuspended, @@ -846,7 +859,7 @@ func Test_Prune(t *testing.T) { }) lggr, observed := logger.TestLoggerObserved(t, zapcore.DebugLevel) db := pgtest.NewSqlxDB(t) - porm := pipeline.NewORM(db, lggr, cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) + porm := pipeline.NewORM(db, lggr, cfg.JobPipeline().MaxSuccessfulRuns()) torm := newTestORM(porm, db) ps1 := cltest.MustInsertPipelineSpec(t, db) diff --git a/core/services/pipeline/runner.go b/core/services/pipeline/runner.go index 08d371716fc..862d2f49178 100644 --- a/core/services/pipeline/runner.go +++ b/core/services/pipeline/runner.go @@ -15,6 +15,7 @@ import ( "gopkg.in/guregu/null.v4" "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" commonutils "github.com/smartcontractkit/chainlink-common/pkg/utils" "github.com/smartcontractkit/chainlink-common/pkg/utils/jsonserializable" @@ -23,7 +24,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/config/env" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/recovery" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/store/models" "github.com/smartcontractkit/chainlink/v2/core/utils" ) @@ -36,15 +36,16 @@ type Runner interface { // Run is a blocking call that will execute the run until no further progress can be made. // If `incomplete` is true, the run is only partially complete and is suspended, awaiting to be resumed when more data comes in. // Note that `saveSuccessfulTaskRuns` value is ignored if the run contains async tasks. - Run(ctx context.Context, run *Run, l logger.Logger, saveSuccessfulTaskRuns bool, fn func(tx pg.Queryer) error) (incomplete bool, err error) - ResumeRun(taskID uuid.UUID, value interface{}, err error) error + Run(ctx context.Context, run *Run, l logger.Logger, saveSuccessfulTaskRuns bool, fn func(tx sqlutil.DataSource) error) (incomplete bool, err error) + ResumeRun(ctx context.Context, taskID uuid.UUID, value interface{}, err error) error // ExecuteRun executes a new run in-memory according to a spec and returns the results. // We expect spec.JobID and spec.JobName to be set for logging/prometheus. ExecuteRun(ctx context.Context, spec Spec, vars Vars, l logger.Logger) (run *Run, trrs TaskRunResults, err error) // InsertFinishedRun saves the run results in the database. - InsertFinishedRun(run *Run, saveSuccessfulTaskRuns bool, qopts ...pg.QOpt) error - InsertFinishedRuns(runs []*Run, saveSuccessfulTaskRuns bool, qopts ...pg.QOpt) error + // ds is an optional override, for example when executing a transaction. + InsertFinishedRun(ctx context.Context, ds sqlutil.DataSource, run *Run, saveSuccessfulTaskRuns bool) error + InsertFinishedRuns(ctx context.Context, ds sqlutil.DataSource, runs []*Run, saveSuccessfulTaskRuns bool) error // ExecuteAndInsertFinishedRun executes a new run in-memory according to a spec, persists and saves the results. // It is a combination of ExecuteRun and InsertFinishedRun. @@ -566,9 +567,9 @@ func (r *runner) ExecuteAndInsertFinishedRun(ctx context.Context, spec Spec, var } if spec.ID == 0 { - err = r.orm.InsertFinishedRunWithSpec(run, saveSuccessfulTaskRuns) + err = r.orm.InsertFinishedRunWithSpec(ctx, run, saveSuccessfulTaskRuns) } else { - err = r.orm.InsertFinishedRun(run, saveSuccessfulTaskRuns) + err = r.orm.InsertFinishedRun(ctx, run, saveSuccessfulTaskRuns) } if err != nil { return 0, trrs, pkgerrors.Wrapf(err, "error inserting finished results for spec ID %v", run.PipelineSpecID) @@ -577,7 +578,7 @@ func (r *runner) ExecuteAndInsertFinishedRun(ctx context.Context, spec Spec, var } -func (r *runner) Run(ctx context.Context, run *Run, l logger.Logger, saveSuccessfulTaskRuns bool, fn func(tx pg.Queryer) error) (incomplete bool, err error) { +func (r *runner) Run(ctx context.Context, run *Run, l logger.Logger, saveSuccessfulTaskRuns bool, fn func(tx sqlutil.DataSource) error) (incomplete bool, err error) { pipeline, err := r.InitializePipeline(run.PipelineSpec) if err != nil { return false, err @@ -594,8 +595,7 @@ func (r *runner) Run(ctx context.Context, run *Run, l logger.Logger, saveSuccess preinsert := pipeline.RequiresPreInsert() - q := r.orm.GetQ().WithOpts(pg.WithParentCtx(ctx)) - err = q.Transaction(func(tx pg.Queryer) error { + err = r.orm.Transact(ctx, func(tx ORM) error { // OPTIMISATION: avoid an extra db write if there is no async tasks present or if this is a resumed run if preinsert && run.ID == 0 { now := time.Now() @@ -614,13 +614,13 @@ func (r *runner) Run(ctx context.Context, run *Run, l logger.Logger, saveSuccess default: } } - if err = r.orm.CreateRun(run, pg.WithQueryer(tx)); err != nil { + if err = tx.CreateRun(ctx, run); err != nil { return err } } if fn != nil { - return fn(tx) + return fn(tx.DataSource()) } return nil }) @@ -634,14 +634,14 @@ func (r *runner) Run(ctx context.Context, run *Run, l logger.Logger, saveSuccess if preinsert { // FailSilently = run failed and task was marked failEarly. skip StoreRun and instead delete all trace of it if run.FailSilently { - if err = r.orm.DeleteRun(run.ID); err != nil { + if err = r.orm.DeleteRun(ctx, run.ID); err != nil { return false, pkgerrors.Wrap(err, "Run") } return false, nil } var restart bool - restart, err = r.orm.StoreRun(run) + restart, err = r.orm.StoreRun(ctx, run) if err != nil { return false, pkgerrors.Wrapf(err, "error storing run for spec ID %v state %v outputs %v errors %v finished_at %v", run.PipelineSpec.ID, run.State, run.Outputs, run.FatalErrors, run.FinishedAt) @@ -660,7 +660,7 @@ func (r *runner) Run(ctx context.Context, run *Run, l logger.Logger, saveSuccess return false, nil } - if err = r.orm.InsertFinishedRun(run, saveSuccessfulTaskRuns, pg.WithParentCtx(ctx)); err != nil { + if err = r.orm.InsertFinishedRun(ctx, run, saveSuccessfulTaskRuns); err != nil { return false, pkgerrors.Wrapf(err, "error storing run for spec ID %v", run.PipelineSpec.ID) } } @@ -671,8 +671,8 @@ func (r *runner) Run(ctx context.Context, run *Run, l logger.Logger, saveSuccess } } -func (r *runner) ResumeRun(taskID uuid.UUID, value interface{}, err error) error { - run, start, err := r.orm.UpdateTaskRunResult(taskID, Result{ +func (r *runner) ResumeRun(ctx context.Context, taskID uuid.UUID, value interface{}, err error) error { + run, start, err := r.orm.UpdateTaskRunResult(ctx, taskID, Result{ Value: value, Error: err, }) @@ -694,12 +694,20 @@ func (r *runner) ResumeRun(taskID uuid.UUID, value interface{}, err error) error return nil } -func (r *runner) InsertFinishedRun(run *Run, saveSuccessfulTaskRuns bool, qopts ...pg.QOpt) error { - return r.orm.InsertFinishedRun(run, saveSuccessfulTaskRuns, qopts...) +func (r *runner) InsertFinishedRun(ctx context.Context, ds sqlutil.DataSource, run *Run, saveSuccessfulTaskRuns bool) error { + orm := r.orm + if ds != nil { + orm = orm.WithDataSource(ds) + } + return orm.InsertFinishedRun(ctx, run, saveSuccessfulTaskRuns) } -func (r *runner) InsertFinishedRuns(runs []*Run, saveSuccessfulTaskRuns bool, qopts ...pg.QOpt) error { - return r.orm.InsertFinishedRuns(runs, saveSuccessfulTaskRuns, qopts...) +func (r *runner) InsertFinishedRuns(ctx context.Context, ds sqlutil.DataSource, runs []*Run, saveSuccessfulTaskRuns bool) error { + orm := r.orm + if ds != nil { + orm = orm.WithDataSource(ds) + } + return orm.InsertFinishedRuns(ctx, runs, saveSuccessfulTaskRuns) } func (r *runner) runReaper() { diff --git a/core/services/pipeline/runner_test.go b/core/services/pipeline/runner_test.go index 52e668339ec..f27a6b35348 100644 --- a/core/services/pipeline/runner_test.go +++ b/core/services/pipeline/runner_test.go @@ -476,7 +476,7 @@ func Test_PipelineRunner_HandleFaultsPersistRun(t *testing.T) { orm.On("GetQ").Return(q).Maybe() orm.On("InsertFinishedRun", mock.Anything, mock.Anything, mock.Anything). Run(func(args mock.Arguments) { - args.Get(0).(*pipeline.Run).ID = 1 + args.Get(1).(*pipeline.Run).ID = 1 }). Return(nil) cfg := configtest.NewTestGeneralConfig(t) @@ -517,7 +517,7 @@ func Test_PipelineRunner_ExecuteAndInsertFinishedRun_SavingTheSpec(t *testing.T) orm.On("GetQ").Return(q).Maybe() orm.On("InsertFinishedRunWithSpec", mock.Anything, mock.Anything, mock.Anything). Run(func(args mock.Arguments) { - args.Get(0).(*pipeline.Run).ID = 1 + args.Get(1).(*pipeline.Run).ID = 1 }). Return(nil) cfg := configtest.NewTestGeneralConfig(t) @@ -642,7 +642,13 @@ func Test_PipelineRunner_AsyncJob_Basic(t *testing.T) { btORM := bridgesMocks.NewORM(t) btORM.On("FindBridge", mock.Anything, bt.Name).Return(*bt, nil) + r, orm := newRunner(t, db, btORM, cfg) + transactCall := orm.On("Transact", mock.Anything, mock.Anything) + transactCall.Run(func(args mock.Arguments) { + fn := args[1].(func(orm pipeline.ORM) error) + transactCall.ReturnArguments = mock.Arguments{fn(orm)} + }) s := fmt.Sprintf(` ds1 [type=bridge async=true name="%s" timeout=0 requestData=<{"data": {"coin": "BTC", "market": "USD"}}>] @@ -673,11 +679,11 @@ ds5 [type=http method="GET" url="%s" index=2] // Start a new run run := pipeline.NewRun(spec, pipeline.NewVarsFrom(nil)) // we should receive a call to CreateRun because it's contains an async task - orm.On("CreateRun", mock.AnythingOfType("*pipeline.Run"), mock.Anything).Return(nil).Run(func(args mock.Arguments) { - run := args.Get(0).(*pipeline.Run) + orm.On("CreateRun", mock.Anything, mock.AnythingOfType("*pipeline.Run")).Return(nil).Run(func(args mock.Arguments) { + run := args.Get(1).(*pipeline.Run) run.ID = 1 // give it a valid "id" }).Once() - orm.On("StoreRun", mock.AnythingOfType("*pipeline.Run"), mock.Anything).Return(false, nil).Once() + orm.On("StoreRun", mock.Anything, mock.AnythingOfType("*pipeline.Run")).Return(false, nil).Once() lggr := logger.TestLogger(t) incomplete, err := r.Run(testutils.Context(t), run, lggr, false, nil) require.NoError(t, err) @@ -687,7 +693,7 @@ ds5 [type=http method="GET" url="%s" index=2] // TODO: test a pending run that's not marked async=true, that is not allowed // Trigger run resumption with no new data - orm.On("StoreRun", mock.AnythingOfType("*pipeline.Run")).Return(false, nil).Once() + orm.On("StoreRun", mock.Anything, mock.AnythingOfType("*pipeline.Run")).Return(false, nil).Once() incomplete, err = r.Run(testutils.Context(t), run, lggr, false, nil) require.NoError(t, err) require.Equal(t, true, incomplete) // still incomplete @@ -700,7 +706,7 @@ ds5 [type=http method="GET" url="%s" index=2] Valid: true, } // Trigger run resumption - orm.On("StoreRun", mock.AnythingOfType("*pipeline.Run"), mock.Anything).Return(false, nil).Once() + orm.On("StoreRun", mock.Anything, mock.AnythingOfType("*pipeline.Run")).Return(false, nil).Once() incomplete, err = r.Run(testutils.Context(t), run, lggr, false, nil) require.NoError(t, err) require.Equal(t, false, incomplete) // done @@ -770,6 +776,11 @@ func Test_PipelineRunner_AsyncJob_InstantRestart(t *testing.T) { btORM.On("FindBridge", mock.Anything, bt.Name).Return(*bt, nil) r, orm := newRunner(t, db, btORM, cfg) + transactCall := orm.On("Transact", mock.Anything, mock.Anything) + transactCall.Run(func(args mock.Arguments) { + fn := args[1].(func(orm pipeline.ORM) error) + transactCall.ReturnArguments = mock.Arguments{fn(orm)} + }) s := fmt.Sprintf(` ds1 [type=bridge async=true name="%s" timeout=0 requestData=<{"data": {"coin": "BTC", "market": "USD"}}>] @@ -800,13 +811,13 @@ ds5 [type=http method="GET" url="%s" index=2] // Start a new run run := pipeline.NewRun(spec, pipeline.NewVarsFrom(nil)) // we should receive a call to CreateRun because it's contains an async task - orm.On("CreateRun", mock.AnythingOfType("*pipeline.Run"), mock.Anything).Return(nil).Run(func(args mock.Arguments) { - run := args.Get(0).(*pipeline.Run) + orm.On("CreateRun", mock.Anything, mock.AnythingOfType("*pipeline.Run")).Return(nil).Run(func(args mock.Arguments) { + run := args.Get(1).(*pipeline.Run) run.ID = 1 // give it a valid "id" }).Once() // Simulate updated task run data - orm.On("StoreRun", mock.AnythingOfType("*pipeline.Run"), mock.Anything).Return(true, nil).Run(func(args mock.Arguments) { - run := args.Get(0).(*pipeline.Run) + orm.On("StoreRun", mock.Anything, mock.AnythingOfType("*pipeline.Run")).Return(true, nil).Run(func(args mock.Arguments) { + run := args.Get(1).(*pipeline.Run) // Now simulate a new result coming in while we were running task := run.ByDotID("ds1") task.Error = null.NewString("", false) @@ -816,7 +827,7 @@ ds5 [type=http method="GET" url="%s" index=2] } }).Once() // StoreRun is called again to store the final result - orm.On("StoreRun", mock.AnythingOfType("*pipeline.Run"), mock.Anything).Return(false, nil).Once() + orm.On("StoreRun", mock.Anything, mock.AnythingOfType("*pipeline.Run")).Return(false, nil).Once() incomplete, err := r.Run(testutils.Context(t), run, logger.TestLogger(t), false, nil) require.NoError(t, err) require.Len(t, run.PipelineTaskRuns, 12) diff --git a/core/services/pipeline/task.bridge_test.go b/core/services/pipeline/task.bridge_test.go index 922f82a533b..029c6c78ca8 100644 --- a/core/services/pipeline/task.bridge_test.go +++ b/core/services/pipeline/task.bridge_test.go @@ -216,8 +216,8 @@ func TestBridgeTask_Happy(t *testing.T) { RequestData: btcUSDPairing, } c := clhttptest.NewTestLocalOnlyHTTPClient() - trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute), pg.WithParentCtx(testutils.Context(t))) + trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) + specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) @@ -258,8 +258,8 @@ func TestBridgeTask_HandlesIntermittentFailure(t *testing.T) { CacheTTL: "30s", // standard duration string format } c := clhttptest.NewTestLocalOnlyHTTPClient() - trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute), pg.WithParentCtx(testutils.Context(t))) + trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) + specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) result, runInfo := task.Run(testutils.Context(t), logger.TestLogger(t), @@ -321,8 +321,8 @@ func TestBridgeTask_DoesNotReturnStaleResults(t *testing.T) { RequestData: btcUSDPairing, } c := clhttptest.NewTestLocalOnlyHTTPClient() - trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute), pg.WithParentCtx(testutils.Context(t))) + trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) + specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) @@ -481,8 +481,8 @@ func TestBridgeTask_AsyncJobPendingState(t *testing.T) { Async: "true", } c := clhttptest.NewTestLocalOnlyHTTPClient() - trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute), pg.WithParentCtx(testutils.Context(t))) + trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) + specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, id, c) @@ -659,8 +659,8 @@ func TestBridgeTask_Variables(t *testing.T) { IncludeInputAtKey: test.includeInputAtKey, } c := clhttptest.NewTestLocalOnlyHTTPClient() - trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute), pg.WithParentCtx(testutils.Context(t))) + trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) + specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) @@ -728,8 +728,8 @@ func TestBridgeTask_Meta(t *testing.T) { Name: bridge.Name.String(), } c := clhttptest.NewTestLocalOnlyHTTPClient() - trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute), pg.WithParentCtx(testutils.Context(t))) + trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) + specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) @@ -782,8 +782,8 @@ func TestBridgeTask_IncludeInputAtKey(t *testing.T) { IncludeInputAtKey: test.includeInputAtKey, } c := clhttptest.NewTestLocalOnlyHTTPClient() - trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute), pg.WithParentCtx(testutils.Context(t))) + trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) + specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) @@ -838,8 +838,8 @@ func TestBridgeTask_ErrorMessage(t *testing.T) { RequestData: ethUSDPairing, } c := clhttptest.NewTestLocalOnlyHTTPClient() - trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute), pg.WithParentCtx(testutils.Context(t))) + trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) + specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) @@ -877,8 +877,8 @@ func TestBridgeTask_OnlyErrorMessage(t *testing.T) { RequestData: ethUSDPairing, } c := clhttptest.NewTestLocalOnlyHTTPClient() - trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute), pg.WithParentCtx(testutils.Context(t))) + trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) + specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) @@ -902,8 +902,8 @@ func TestBridgeTask_ErrorIfBridgeMissing(t *testing.T) { } c := clhttptest.NewTestLocalOnlyHTTPClient() orm := bridges.NewORM(db) - trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute), pg.WithParentCtx(testutils.Context(t))) + trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) + specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) @@ -992,8 +992,8 @@ func TestBridgeTask_Headers(t *testing.T) { } c := clhttptest.NewTestLocalOnlyHTTPClient() - trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute), pg.WithParentCtx(testutils.Context(t))) + trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) + specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) @@ -1014,8 +1014,8 @@ func TestBridgeTask_Headers(t *testing.T) { } c := clhttptest.NewTestLocalOnlyHTTPClient() - trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute), pg.WithParentCtx(testutils.Context(t))) + trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) + specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) @@ -1036,8 +1036,8 @@ func TestBridgeTask_Headers(t *testing.T) { } c := clhttptest.NewTestLocalOnlyHTTPClient() - trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute), pg.WithParentCtx(testutils.Context(t))) + trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) + specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) @@ -1082,8 +1082,8 @@ func TestBridgeTask_AdapterResponseStatusFailure(t *testing.T) { RequestData: btcUSDPairing, } c := clhttptest.NewTestLocalOnlyHTTPClient() - trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute), pg.WithParentCtx(testutils.Context(t))) + trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) + specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) diff --git a/core/services/pipeline/task.http_test.go b/core/services/pipeline/task.http_test.go index ce28fac478c..6264d1e591b 100644 --- a/core/services/pipeline/task.http_test.go +++ b/core/services/pipeline/task.http_test.go @@ -24,7 +24,6 @@ import ( clhttptest "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/httptest" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/pgtest" "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" "github.com/smartcontractkit/chainlink/v2/core/store/models" "github.com/smartcontractkit/chainlink/v2/core/utils" @@ -177,8 +176,8 @@ func TestHTTPTask_Variables(t *testing.T) { RequestData: test.requestData, } c := clhttptest.NewTestLocalOnlyHTTPClient() - trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) - specID, err := trORM.CreateSpec(pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute), pg.WithParentCtx(testutils.Context(t))) + trORM := pipeline.NewORM(db, logger.TestLogger(t), cfg.JobPipeline().MaxSuccessfulRuns()) + specID, err := trORM.CreateSpec(testutils.Context(t), nil, pipeline.Pipeline{}, *models.NewInterval(5 * time.Minute)) require.NoError(t, err) task.HelperSetDependencies(cfg.JobPipeline(), cfg.WebServer(), orm, specID, uuid.UUID{}, c) diff --git a/core/services/relay/evm/mocks/request_round_db.go b/core/services/relay/evm/mocks/request_round_db.go index 725fc6e6b37..4168ba4a1b0 100644 --- a/core/services/relay/evm/mocks/request_round_db.go +++ b/core/services/relay/evm/mocks/request_round_db.go @@ -9,6 +9,8 @@ import ( mock "github.com/stretchr/testify/mock" ocr2aggregator "github.com/smartcontractkit/libocr/gethwrappers2/ocr2aggregator" + + sqlutil "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" ) // RequestRoundDB is an autogenerated mock type for the RequestRoundDB type @@ -62,19 +64,21 @@ func (_m *RequestRoundDB) SaveLatestRoundRequested(ctx context.Context, rr ocr2a return r0 } -// Transact provides a mock function with given fields: _a0, _a1 -func (_m *RequestRoundDB) Transact(_a0 context.Context, _a1 func(evm.RequestRoundDB) error) error { - ret := _m.Called(_a0, _a1) +// WithDataSource provides a mock function with given fields: _a0 +func (_m *RequestRoundDB) WithDataSource(_a0 sqlutil.DataSource) evm.RequestRoundDB { + ret := _m.Called(_a0) if len(ret) == 0 { - panic("no return value specified for Transact") + panic("no return value specified for WithDataSource") } - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, func(evm.RequestRoundDB) error) error); ok { - r0 = rf(_a0, _a1) + var r0 evm.RequestRoundDB + if rf, ok := ret.Get(0).(func(sqlutil.DataSource) evm.RequestRoundDB); ok { + r0 = rf(_a0) } else { - r0 = ret.Error(0) + if ret.Get(0) != nil { + r0 = ret.Get(0).(evm.RequestRoundDB) + } } return r0 diff --git a/core/services/relay/evm/request_round_db.go b/core/services/relay/evm/request_round_db.go index 2b6ae10782d..96c5a05d1c7 100644 --- a/core/services/relay/evm/request_round_db.go +++ b/core/services/relay/evm/request_round_db.go @@ -12,16 +12,17 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/logger" ) +//go:generate mockery --quiet --name RequestRoundDB --output ./mocks/ --case=underscore + // RequestRoundDB stores requested rounds for querying by the median plugin. type RequestRoundDB interface { SaveLatestRoundRequested(ctx context.Context, rr ocr2aggregator.OCR2AggregatorRoundRequested) error LoadLatestRoundRequested(context.Context) (rr ocr2aggregator.OCR2AggregatorRoundRequested, err error) - Transact(context.Context, func(db RequestRoundDB) error) error + WithDataSource(sqlutil.DataSource) RequestRoundDB } var _ RequestRoundDB = &requestRoundDB{} -//go:generate mockery --quiet --name RequestRoundDB --output ./mocks/ --case=underscore type requestRoundDB struct { ds sqlutil.DataSource oracleSpecID int32 @@ -33,10 +34,8 @@ func NewRoundRequestedDB(ds sqlutil.DataSource, oracleSpecID int32, lggr logger. return &requestRoundDB{ds, oracleSpecID, lggr} } -func (d *requestRoundDB) Transact(ctx context.Context, fn func(db RequestRoundDB) error) error { - return sqlutil.Transact(ctx, func(ds sqlutil.DataSource) RequestRoundDB { - return NewRoundRequestedDB(ds, d.oracleSpecID, d.lggr) - }, d.ds, nil, fn) +func (d *requestRoundDB) WithDataSource(ds sqlutil.DataSource) RequestRoundDB { + return NewRoundRequestedDB(ds, d.oracleSpecID, d.lggr) } func (d *requestRoundDB) SaveLatestRoundRequested(ctx context.Context, rr ocr2aggregator.OCR2AggregatorRoundRequested) error { diff --git a/core/services/relay/evm/request_round_db_test.go b/core/services/relay/evm/request_round_db_test.go index 10932c4e229..26f8e2ac1a6 100644 --- a/core/services/relay/evm/request_round_db_test.go +++ b/core/services/relay/evm/request_round_db_test.go @@ -37,9 +37,7 @@ func Test_DB_LatestRoundRequested(t *testing.T) { t.Run("saves latest round requested", func(t *testing.T) { ctx := testutils.Context(t) - err := db.Transact(ctx, func(tx evm.RequestRoundDB) error { - return tx.SaveLatestRoundRequested(ctx, rr) - }) + err := db.SaveLatestRoundRequested(ctx, rr) require.NoError(t, err) rawLog.Index = 42 @@ -53,9 +51,7 @@ func Test_DB_LatestRoundRequested(t *testing.T) { Raw: rawLog, } - err = db.Transact(ctx, func(tx evm.RequestRoundDB) error { - return tx.SaveLatestRoundRequested(ctx, rr) - }) + err = db.SaveLatestRoundRequested(ctx, rr) require.NoError(t, err) }) diff --git a/core/services/relay/evm/request_round_tracker.go b/core/services/relay/evm/request_round_tracker.go index bb39271f278..fe6b6826eb2 100644 --- a/core/services/relay/evm/request_round_tracker.go +++ b/core/services/relay/evm/request_round_tracker.go @@ -106,8 +106,8 @@ func (t *RequestRoundTracker) Close() error { // HandleLog complies with LogListener interface // It is not thread safe -func (t *RequestRoundTracker) HandleLog(lb log.Broadcast) { - was, err := t.logBroadcaster.WasAlreadyConsumed(t.ctx, lb) +func (t *RequestRoundTracker) HandleLog(ctx context.Context, lb log.Broadcast) { + was, err := t.logBroadcaster.WasAlreadyConsumed(ctx, lb) if err != nil { t.lggr.Errorw("OCRContract: could not determine if log was already consumed", "err", err) return @@ -118,12 +118,12 @@ func (t *RequestRoundTracker) HandleLog(lb log.Broadcast) { raw := lb.RawLog() if raw.Address != t.contract.Address() { t.lggr.Errorf("log address of 0x%x does not match configured contract address of 0x%x", raw.Address, t.contract.Address()) - t.lggr.ErrorIf(t.logBroadcaster.MarkConsumed(t.ctx, lb), "unable to mark consumed") + t.lggr.ErrorIf(t.logBroadcaster.MarkConsumed(ctx, nil, lb), "unable to mark consumed") return } topics := raw.Topics if len(topics) == 0 { - t.lggr.ErrorIf(t.logBroadcaster.MarkConsumed(t.ctx, lb), "unable to mark consumed") + t.lggr.ErrorIf(t.logBroadcaster.MarkConsumed(ctx, nil, lb), "unable to mark consumed") return } @@ -134,16 +134,15 @@ func (t *RequestRoundTracker) HandleLog(lb log.Broadcast) { rr, err = t.contractFilterer.ParseRoundRequested(raw) if err != nil { t.lggr.Errorw("could not parse round requested", "err", err) - t.lggr.ErrorIf(t.logBroadcaster.MarkConsumed(t.ctx, lb), "unable to mark consumed") + t.lggr.ErrorIf(t.logBroadcaster.MarkConsumed(ctx, nil, lb), "unable to mark consumed") return } if IsLaterThan(raw, t.latestRoundRequested.Raw) { - ctx := context.TODO() //TODO https://smartcontract-it.atlassian.net/browse/BCF-2887 - err = t.odb.Transact(ctx, func(tx RequestRoundDB) error { - if err = tx.SaveLatestRoundRequested(ctx, *rr); err != nil { + err = sqlutil.TransactDataSource(ctx, t.ds, nil, func(tx sqlutil.DataSource) error { + if err = t.odb.WithDataSource(tx).SaveLatestRoundRequested(ctx, *rr); err != nil { return err } - return t.logBroadcaster.MarkConsumed(t.ctx, lb) + return t.logBroadcaster.MarkConsumed(ctx, tx, lb) }) if err != nil { t.lggr.Error(err) @@ -161,7 +160,7 @@ func (t *RequestRoundTracker) HandleLog(lb log.Broadcast) { t.lggr.Debugw("RequestRoundTracker: got unrecognised log topic", "topic", topics[0]) } if !consumed { - t.lggr.ErrorIf(t.logBroadcaster.MarkConsumed(t.ctx, lb), "unable to mark consumed") + t.lggr.ErrorIf(t.logBroadcaster.MarkConsumed(ctx, nil, lb), "unable to mark consumed") } } diff --git a/core/services/relay/evm/request_round_tracker_test.go b/core/services/relay/evm/request_round_tracker_test.go index 9feb4b77348..3421004ccf5 100644 --- a/core/services/relay/evm/request_round_tracker_test.go +++ b/core/services/relay/evm/request_round_tracker_test.go @@ -14,6 +14,7 @@ import ( "github.com/smartcontractkit/libocr/gethwrappers2/ocr2aggregator" ocrtypes "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" htmocks "github.com/smartcontractkit/chainlink/v2/common/headtracker/mocks" evmclimocks "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client/mocks" evmconfig "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config" @@ -112,7 +113,7 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin rawLog := cltest.LogFromFixture(t, "../../../testdata/jsonrpc/ocr2_round_requested_log_1_1.json") logBroadcast.On("RawLog").Return(rawLog).Maybe() logBroadcast.On("String").Return("").Maybe() - uni.lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + uni.lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) uni.lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) configDigest, epoch, round, err := uni.requestRoundTracker.LatestRoundRequested(testutils.Context(t), 0) @@ -121,7 +122,7 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin require.Equal(t, 0, int(round)) require.Equal(t, 0, int(epoch)) - uni.requestRoundTracker.HandleLog(logBroadcast) + uni.requestRoundTracker.HandleLog(tests.Context(t), logBroadcast) configDigest, epoch, round, err = uni.requestRoundTracker.LatestRoundRequested(testutils.Context(t), 0) require.NoError(t, err) @@ -143,7 +144,7 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin require.Equal(t, 0, int(round)) require.Equal(t, 0, int(epoch)) - uni.requestRoundTracker.HandleLog(logBroadcast) + uni.requestRoundTracker.HandleLog(tests.Context(t), logBroadcast) configDigest, epoch, round, err = uni.requestRoundTracker.LatestRoundRequested(testutils.Context(t), 0) require.NoError(t, err) @@ -168,19 +169,14 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin logBroadcast.On("RawLog").Return(rawLog).Maybe() logBroadcast.On("String").Return("").Maybe() uni.lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) - uni.lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + uni.lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) uni.db.On("SaveLatestRoundRequested", mock.Anything, mock.MatchedBy(func(rr ocr2aggregator.OCR2AggregatorRoundRequested) bool { return rr.Epoch == 1 && rr.Round == 1 })).Return(nil) - transact := uni.db.On("Transact", mock.Anything, mock.Anything) - transact.Run(func(args mock.Arguments) { - fn := args[1].(func(evm.RequestRoundDB) error) - err2 := fn(uni.db) - transact.ReturnArguments = []any{err2} - }) + uni.db.On("WithDataSource", mock.Anything).Return(uni.db) - uni.requestRoundTracker.HandleLog(logBroadcast) + uni.requestRoundTracker.HandleLog(tests.Context(t), logBroadcast) configDigest, epoch, round, err = uni.requestRoundTracker.LatestRoundRequested(testutils.Context(t), 0) require.NoError(t, err) @@ -194,13 +190,13 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin logBroadcast2.On("RawLog").Return(rawLog2).Maybe() logBroadcast2.On("String").Return("").Maybe() uni.lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) - uni.lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + uni.lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) uni.db.On("SaveLatestRoundRequested", mock.Anything, mock.MatchedBy(func(rr ocr2aggregator.OCR2AggregatorRoundRequested) bool { return rr.Epoch == 1 && rr.Round == 9 })).Return(nil) - uni.requestRoundTracker.HandleLog(logBroadcast2) + uni.requestRoundTracker.HandleLog(tests.Context(t), logBroadcast2) configDigest, epoch, round, err = uni.requestRoundTracker.LatestRoundRequested(testutils.Context(t), 0) require.NoError(t, err) @@ -209,7 +205,7 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin assert.Equal(t, 9, int(round)) // Same round with lower epoch is ignored - uni.requestRoundTracker.HandleLog(logBroadcast) + uni.requestRoundTracker.HandleLog(tests.Context(t), logBroadcast) configDigest, epoch, round, err = uni.requestRoundTracker.LatestRoundRequested(testutils.Context(t), 0) require.NoError(t, err) @@ -224,13 +220,13 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin logBroadcast3.On("RawLog").Return(rawLog3).Maybe() logBroadcast3.On("String").Return("").Maybe() uni.lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) - uni.lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil) + uni.lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil) uni.db.On("SaveLatestRoundRequested", mock.Anything, mock.MatchedBy(func(rr ocr2aggregator.OCR2AggregatorRoundRequested) bool { return rr.Epoch == 2 && rr.Round == 1 })).Return(nil) - uni.requestRoundTracker.HandleLog(logBroadcast3) + uni.requestRoundTracker.HandleLog(tests.Context(t), logBroadcast3) configDigest, epoch, round, err = uni.requestRoundTracker.LatestRoundRequested(testutils.Context(t), 0) require.NoError(t, err) @@ -250,14 +246,9 @@ func Test_OCRContractTracker_HandleLog_OCRContractLatestRoundRequested(t *testin uni.lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) uni.db.On("SaveLatestRoundRequested", mock.Anything, mock.Anything).Return(errors.New("something exploded")) - transact := uni.db.On("Transact", mock.Anything, mock.Anything) - transact.Run(func(args mock.Arguments) { - fn := args[1].(func(evm.RequestRoundDB) error) - err := fn(uni.db) - transact.ReturnArguments = []any{err} - }) + uni.db.On("WithDataSource", mock.Anything).Return(uni.db) - uni.requestRoundTracker.HandleLog(logBroadcast) + uni.requestRoundTracker.HandleLog(tests.Context(t), logBroadcast) configDigest, epoch, round, err := uni.requestRoundTracker.LatestRoundRequested(testutils.Context(t), 0) require.NoError(t, err) diff --git a/core/services/streams/delegate.go b/core/services/streams/delegate.go index f9e2a64c4a3..bf492d4bd15 100644 --- a/core/services/streams/delegate.go +++ b/core/services/streams/delegate.go @@ -12,7 +12,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/job" "github.com/smartcontractkit/chainlink/v2/core/services/ocrcommon" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" ) @@ -38,10 +37,10 @@ func (d *Delegate) JobType() job.Type { return job.Stream } -func (d *Delegate) BeforeJobCreated(jb job.Job) {} -func (d *Delegate) AfterJobCreated(jb job.Job) {} -func (d *Delegate) BeforeJobDeleted(jb job.Job) {} -func (d *Delegate) OnDeleteJob(ctx context.Context, jb job.Job, q pg.Queryer) error { return nil } +func (d *Delegate) BeforeJobCreated(jb job.Job) {} +func (d *Delegate) AfterJobCreated(jb job.Job) {} +func (d *Delegate) BeforeJobDeleted(jb job.Job) {} +func (d *Delegate) OnDeleteJob(context.Context, job.Job) error { return nil } func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) (services []job.ServiceCtx, err error) { if jb.StreamID == nil { diff --git a/core/services/streams/stream_test.go b/core/services/streams/stream_test.go index 3c0b4d0721f..3e8f58cd58b 100644 --- a/core/services/streams/stream_test.go +++ b/core/services/streams/stream_test.go @@ -11,9 +11,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" ) @@ -32,7 +32,7 @@ func (m *mockRunner) ExecuteRun(ctx context.Context, spec pipeline.Spec, vars pi func (m *mockRunner) InitializePipeline(spec pipeline.Spec) (p *pipeline.Pipeline, err error) { return m.p, m.err } -func (m *mockRunner) InsertFinishedRun(run *pipeline.Run, saveSuccessfulTaskRuns bool, qopts ...pg.QOpt) error { +func (m *mockRunner) InsertFinishedRun(ctx context.Context, ds sqlutil.DataSource, run *pipeline.Run, saveSuccessfulTaskRuns bool) error { return m.err } diff --git a/core/services/vrf/delegate.go b/core/services/vrf/delegate.go index 617a28ac4d5..84c5126afef 100644 --- a/core/services/vrf/delegate.go +++ b/core/services/vrf/delegate.go @@ -11,8 +11,7 @@ import ( "github.com/theodesp/go-heaps/pairing" "go.uber.org/multierr" - "github.com/jmoiron/sqlx" - + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/utils/mailbox" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/assets" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/log" @@ -26,7 +25,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/job" "github.com/smartcontractkit/chainlink/v2/core/services/keystore" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" v1 "github.com/smartcontractkit/chainlink/v2/core/services/vrf/v1" v2 "github.com/smartcontractkit/chainlink/v2/core/services/vrf/v2" @@ -34,7 +32,7 @@ import ( ) type Delegate struct { - q pg.Q + ds sqlutil.DataSource pr pipeline.Runner porm pipeline.ORM ks keystore.Master @@ -44,16 +42,15 @@ type Delegate struct { } func NewDelegate( - db *sqlx.DB, + ds sqlutil.DataSource, ks keystore.Master, pr pipeline.Runner, porm pipeline.ORM, legacyChains legacyevm.LegacyChainContainer, lggr logger.Logger, - cfg pg.QConfig, mailMon *mailbox.Monitor) *Delegate { return &Delegate{ - q: pg.NewQ(db, lggr, cfg), + ds: ds, ks: ks, pr: pr, porm: porm, @@ -67,10 +64,10 @@ func (d *Delegate) JobType() job.Type { return job.VRF } -func (d *Delegate) BeforeJobCreated(job.Job) {} -func (d *Delegate) AfterJobCreated(job.Job) {} -func (d *Delegate) BeforeJobDeleted(job.Job) {} -func (d *Delegate) OnDeleteJob(context.Context, job.Job, pg.Queryer) error { return nil } +func (d *Delegate) BeforeJobCreated(job.Job) {} +func (d *Delegate) AfterJobCreated(job.Job) {} +func (d *Delegate) BeforeJobDeleted(job.Job) {} +func (d *Delegate) OnDeleteJob(context.Context, job.Job) error { return nil } // ServicesForSpec satisfies the job.Delegate interface. func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) ([]job.ServiceCtx, error) { @@ -171,7 +168,7 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) ([]job.Servi lV2Plus, chain, chain.ID(), - d.q, + d.ds, v2.NewCoordinatorV2_5(coordinatorV2Plus), batchCoordinatorV2, vrfOwner, @@ -225,7 +222,7 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) ([]job.Servi lV2, chain, chain.ID(), - d.q, + d.ds, v2.NewCoordinatorV2(coordinatorV2), batchCoordinatorV2, vrfOwner, @@ -246,7 +243,6 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) ([]job.Servi Cfg: chain.Config().EVM(), FeeCfg: chain.Config().EVM().GasEstimator(), L: logger.Sugared(lV1), - Q: d.q, Coordinator: coordinator, PipelineRunner: d.pr, GethKs: d.ks.Eth(), diff --git a/core/services/vrf/delegate_test.go b/core/services/vrf/delegate_test.go index d009641e65f..db9724179e7 100644 --- a/core/services/vrf/delegate_test.go +++ b/core/services/vrf/delegate_test.go @@ -78,7 +78,7 @@ func buildVrfUni(t *testing.T, db *sqlx.DB, cfg chainlink.GeneralConfig) vrfUniv hb := headtracker.NewHeadBroadcaster(lggr) // Don't mock db interactions - prm := pipeline.NewORM(db, lggr, cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) + prm := pipeline.NewORM(db, lggr, cfg.JobPipeline().MaxSuccessfulRuns()) btORM := bridges.NewORM(db) ks := keystore.NewInMemory(db, utils.FastScryptParams, lggr, cfg.Database()) _, dbConfig, evmConfig := txmgr.MakeTestConfigs(t) @@ -160,7 +160,6 @@ func setup(t *testing.T) (vrfUniverse, *v1.Listener, job.Job) { vuni.prm, vuni.legacyChains, logger.TestLogger(t), - cfg.Database(), mailMon) vs := testspecs.GenerateVRFSpec(testspecs.VRFSpecParams{PublicKey: vuni.vrfkey.PublicKey.String(), EVMChainID: testutils.FixtureChainID.String()}) jb, err := vrfcommon.ValidatedVRFSpec(vs.Toml()) @@ -201,9 +200,10 @@ func TestDelegate_ReorgAttackProtection(t *testing.T) { preSeed := common.BigToHash(big.NewInt(42)).Bytes() txHash := evmutils.NewHash() vuni.lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil).Maybe() - vuni.lb.On("MarkConsumed", mock.Anything, mock.Anything).Return(nil).Maybe() + vuni.lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() vuni.ec.On("CallContract", mock.Anything, mock.Anything, mock.Anything).Return(generateCallbackReturnValues(t, false), nil).Maybe() - listener.HandleLog(log.NewLogBroadcast(types.Log{ + ctx := testutils.Context(t) + listener.HandleLog(ctx, log.NewLogBroadcast(types.Log{ // Data has all the NON-indexed parameters Data: bytes.Join([][]byte{pk.MustHash().Bytes(), // key hash preSeed, // preSeed @@ -302,14 +302,15 @@ func TestDelegate_ValidLog(t *testing.T) { consumed := make(chan struct{}) for i, tc := range tt { tc := tc + ctx := testutils.Context(t) vuni.lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) - vuni.lb.On("MarkConsumed", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + vuni.lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { consumed <- struct{}{} }).Return(nil).Once() // Expect a call to check if the req is already fulfilled. vuni.ec.On("CallContract", mock.Anything, mock.Anything, mock.Anything).Return(generateCallbackReturnValues(t, false), nil) - listener.HandleLog(log.NewLogBroadcast(tc.log, vuni.cid, nil)) + listener.HandleLog(ctx, log.NewLogBroadcast(tc.log, vuni.cid, nil)) // Wait until the log is present waitForChannel(t, added, time.Second, "request not added to the queue") // Feed it a head which confirms it. @@ -318,7 +319,7 @@ func TestDelegate_ValidLog(t *testing.T) { // Ensure we created a successful run. waitForChannel(t, runComplete, 2*time.Second, "pipeline not complete") - runs, err := vuni.prm.GetAllRuns() + runs, err := vuni.prm.GetAllRuns(ctx) require.NoError(t, err) require.Equal(t, i+1, len(runs)) assert.False(t, runs[0].FatalErrors.HasError()) @@ -328,13 +329,13 @@ func TestDelegate_ValidLog(t *testing.T) { p, err := vuni.ks.VRF().GenerateProof(keyID, evmutils.MustHash(string(bytes.Join([][]byte{preSeed, bh.Bytes()}, []byte{}))).Big()) require.NoError(t, err) vuni.lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) - vuni.lb.On("MarkConsumed", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + vuni.lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { consumed <- struct{}{} }).Return(nil).Once() // If we send a completed log we should the respCount increase var reqIDBytes []byte copy(reqIDBytes[:], tc.reqID[:]) - listener.HandleLog(log.NewLogBroadcast(types.Log{ + listener.HandleLog(ctx, log.NewLogBroadcast(types.Log{ // Data has all the NON-indexed parameters Data: bytes.Join([][]byte{reqIDBytes, // output p.Output.Bytes(), @@ -354,7 +355,7 @@ func TestDelegate_InvalidLog(t *testing.T) { vuni, listener, jb := setup(t) vuni.lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) done := make(chan struct{}) - vuni.lb.On("MarkConsumed", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + vuni.lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { done <- struct{}{} }).Return(nil).Once() // Expect a call to check if the req is already fulfilled. @@ -365,7 +366,8 @@ func TestDelegate_InvalidLog(t *testing.T) { added <- struct{}{} }) // Send an invalid log (keyhash doesnt match) - listener.HandleLog(log.NewLogBroadcast(types.Log{ + ctx := testutils.Context(t) + listener.HandleLog(ctx, log.NewLogBroadcast(types.Log{ // Data has all the NON-indexed parameters Data: append(append(append(append( evmutils.NewHash().Bytes(), // key hash @@ -392,7 +394,7 @@ func TestDelegate_InvalidLog(t *testing.T) { waitForChannel(t, done, time.Second, "log not consumed") // Should create a run that errors in the vrf task - runs, err := vuni.prm.GetAllRuns() + runs, err := vuni.prm.GetAllRuns(ctx) require.NoError(t, err) require.Equal(t, len(runs), 1) for _, tr := range runs[0].PipelineTaskRuns { @@ -417,7 +419,7 @@ func TestFulfilledCheck(t *testing.T) { vuni, listener, jb := setup(t) vuni.lb.On("WasAlreadyConsumed", mock.Anything, mock.Anything).Return(false, nil) done := make(chan struct{}) - vuni.lb.On("MarkConsumed", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + vuni.lb.On("MarkConsumed", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { done <- struct{}{} }).Return(nil).Once() // Expect a call to check if the req is already fulfilled. @@ -429,7 +431,8 @@ func TestFulfilledCheck(t *testing.T) { added <- struct{}{} }) // Send an invalid log (keyhash doesn't match) - listener.HandleLog(log.NewLogBroadcast( + ctx := testutils.Context(t) + listener.HandleLog(ctx, log.NewLogBroadcast( types.Log{ // Data has all the NON-indexed parameters Data: bytes.Join([][]byte{ @@ -455,7 +458,7 @@ func TestFulfilledCheck(t *testing.T) { waitForChannel(t, done, time.Second, "log not consumed") // Should consume the log with no run - runs, err := vuni.prm.GetAllRuns() + runs, err := vuni.prm.GetAllRuns(ctx) require.NoError(t, err) require.Equal(t, len(runs), 0) } @@ -685,7 +688,6 @@ func Test_VRFV2PlusServiceFailsWhenVRFOwnerProvided(t *testing.T) { vuni.prm, vuni.legacyChains, logger.TestLogger(t), - cfg.Database(), mailMon) chain, err := vuni.legacyChains.Get(testutils.FixtureChainID.String()) require.NoError(t, err) diff --git a/core/services/vrf/v1/integration_test.go b/core/services/vrf/v1/integration_test.go index f68700a8af7..1d11615950b 100644 --- a/core/services/vrf/v1/integration_test.go +++ b/core/services/vrf/v1/integration_test.go @@ -45,6 +45,7 @@ func TestIntegration_VRF_JPV2(t *testing.T) { for _, tt := range tests { test := tt t.Run(test.name, func(t *testing.T) { + ctx := testutils.Context(t) config, _ := heavyweight.FullTestDBV2(t, func(c *chainlink.Config, s *chainlink.Secrets) { c.EVM[0].GasEstimator.EIP1559DynamicFees = &test.eip1559 c.EVM[0].ChainID = (*ubig.Big)(testutils.SimulatedChainID) @@ -75,7 +76,7 @@ func TestIntegration_VRF_JPV2(t *testing.T) { } var runs []pipeline.Run gomega.NewWithT(t).Eventually(func() bool { - runs, err = app.PipelineORM().GetAllRuns() + runs, err = app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) // It possible that we send the test request // before the Job spawner has started the vrf services, which is fine @@ -128,6 +129,7 @@ func TestIntegration_VRF_JPV2(t *testing.T) { func TestIntegration_VRF_WithBHS(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) config, _ := heavyweight.FullTestDBV2(t, func(c *chainlink.Config, s *chainlink.Secrets) { c.EVM[0].GasEstimator.EIP1559DynamicFees = ptr(true) c.EVM[0].BlockBackfillDepth = ptr[uint32](500) @@ -196,7 +198,7 @@ func TestIntegration_VRF_WithBHS(t *testing.T) { var runs []pipeline.Run gomega.NewWithT(t).Eventually(func() bool { - runs, err = app.PipelineORM().GetAllRuns() + runs, err = app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) cu.Backend.Commit() return len(runs) == 1 && runs[0].State == pipeline.RunStatusCompleted diff --git a/core/services/vrf/v1/listener_v1.go b/core/services/vrf/v1/listener_v1.go index c57265634e5..ddf5779deb0 100644 --- a/core/services/vrf/v1/listener_v1.go +++ b/core/services/vrf/v1/listener_v1.go @@ -17,6 +17,7 @@ import ( "github.com/theodesp/go-heaps/pairing" "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/utils/mailbox" "github.com/smartcontractkit/chainlink-common/pkg/utils/mathutil" @@ -303,17 +304,16 @@ func (lsn *Listener) RunLogListener(unsubscribes []func(), minConfs uint32) { break } recovery.WrapRecover(lsn.L, func() { - lsn.handleLog(lb, minConfs) + ctx, cancel := lsn.ChStop.NewCtx() + defer cancel() + lsn.handleLog(ctx, lb, minConfs) }) } } } } -func (lsn *Listener) handleLog(lb log.Broadcast, minConfs uint32) { - ctx, cancel := lsn.ChStop.NewCtx() - defer cancel() - +func (lsn *Listener) handleLog(ctx context.Context, lb log.Broadcast, minConfs uint32) { lggr := lsn.L.With( "log", lb.String(), "decodedLog", lb.DecodedLog(), @@ -380,7 +380,7 @@ func (lsn *Listener) shouldProcessLog(ctx context.Context, lb log.Broadcast) boo } func (lsn *Listener) markLogAsConsumed(ctx context.Context, lb log.Broadcast) { - err := lsn.Chain.LogBroadcaster().MarkConsumed(ctx, lb) + err := lsn.Chain.LogBroadcaster().MarkConsumed(ctx, nil, lb) lsn.L.ErrorIf(err, fmt.Sprintf("Unable to mark log %v as consumed", lb.String())) } @@ -486,9 +486,10 @@ func (lsn *Listener) ProcessRequest(ctx context.Context, req request) bool { run := pipeline.NewRun(*lsn.Job.PipelineSpec, vars) // The VRF pipeline has no async tasks, so we don't need to check for `incomplete` - if _, err = lsn.PipelineRunner.Run(ctx, run, lggr, true, func(tx pg.Queryer) error { + if _, err = lsn.PipelineRunner.Run(ctx, run, lggr, true, func(tx sqlutil.DataSource) error { // Always mark consumed regardless of whether the proof failed or not. - if err = lsn.Chain.LogBroadcaster().MarkConsumed(ctx, req.lb); err != nil { + //TODO restore tx https://smartcontract-it.atlassian.net/browse/BCF-2978 + if err = lsn.Chain.LogBroadcaster().MarkConsumed(ctx, nil, req.lb); err != nil { lggr.Errorw("Failed mark consumed", "err", err) } return nil @@ -525,7 +526,7 @@ func (lsn *Listener) Close() error { }) } -func (lsn *Listener) HandleLog(lb log.Broadcast) { +func (lsn *Listener) HandleLog(ctx context.Context, lb log.Broadcast) { if !lsn.Deduper.ShouldDeliver(lb.RawLog()) { lsn.L.Tracew("skipping duplicate log broadcast", "log", lb.RawLog()) return diff --git a/core/services/vrf/v2/integration_helpers_test.go b/core/services/vrf/v2/integration_helpers_test.go index f19f39f03f2..3d7a94ae833 100644 --- a/core/services/vrf/v2/integration_helpers_test.go +++ b/core/services/vrf/v2/integration_helpers_test.go @@ -62,6 +62,7 @@ func testSingleConsumerHappyPath( rwfe v22.RandomWordsFulfilled, subID *big.Int), ) { + ctx := testutils.Context(t) key1 := cltest.MustGenerateRandomKey(t) key2 := cltest.MustGenerateRandomKey(t) gasLanePriceWei := assets.GWei(10) @@ -87,7 +88,7 @@ func testSingleConsumerHappyPath( // Fund gas lanes. sendEth(t, ownerKey, uni.backend, key1.Address, 10) sendEth(t, ownerKey, uni.backend, key2.Address, 10) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // Create VRF job using key1 and key2 on the same gas lane. jbs := createVRFJobs( @@ -111,7 +112,7 @@ func testSingleConsumerHappyPath( // Wait for fulfillment to be queued. gomega.NewGomegaWithT(t).Eventually(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("runs", len(runs)) return len(runs) == 1 @@ -133,7 +134,7 @@ func testSingleConsumerHappyPath( requestID2, _ := requestRandomnessAndAssertRandomWordsRequestedEvent(t, consumerContract, consumer, keyHash, subID, numWords, 500_000, coordinator, uni.backend, nativePayment) gomega.NewGomegaWithT(t).Eventually(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("runs", len(runs)) return len(runs) == 2 @@ -153,11 +154,11 @@ func testSingleConsumerHappyPath( assertNumRandomWords(t, consumerContract, numWords) // Assert that both send addresses were used to fulfill the requests - n, err := uni.backend.PendingNonceAt(testutils.Context(t), key1.Address) + n, err := uni.backend.PendingNonceAt(ctx, key1.Address) require.NoError(t, err) require.EqualValues(t, 1, n) - n, err = uni.backend.PendingNonceAt(testutils.Context(t), key2.Address) + n, err = uni.backend.PendingNonceAt(ctx, key2.Address) require.NoError(t, err) require.EqualValues(t, 1, n) @@ -182,6 +183,7 @@ func testMultipleConsumersNeedBHS( coordinator v22.CoordinatorV2_X, rwfe v22.RandomWordsFulfilled), ) { + ctx := testutils.Context(t) nConsumers := len(consumers) vrfKey := cltest.MustGenerateRandomKey(t) sendEth(t, ownerKey, uni.backend, vrfKey.Address, 10) @@ -216,7 +218,7 @@ func testMultipleConsumersNeedBHS( }) keys = append(keys, ownerKey, vrfKey) app := cltest.NewApplicationWithConfigV2AndKeyOnSimulatedBlockchain(t, config, uni.backend, keys...) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // Create VRF job. vrfJobs := createVRFJobs( @@ -250,7 +252,7 @@ func testMultipleConsumersNeedBHS( // Ensure log poller is ready and has all logs. require.NoError(t, app.GetRelayers().LegacyEVMChains().Slice()[0].LogPoller().Ready()) - require.NoError(t, app.GetRelayers().LegacyEVMChains().Slice()[0].LogPoller().Replay(testutils.Context(t), 1)) + require.NoError(t, app.GetRelayers().LegacyEVMChains().Slice()[0].LogPoller().Replay(ctx, 1)) for i := 0; i < nConsumers; i++ { consumer := consumers[i] @@ -284,7 +286,7 @@ func testMultipleConsumersNeedBHS( // Wait for fulfillment to be queued. gomega.NewGomegaWithT(t).Eventually(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("runs", len(runs)) return len(runs) == 1 @@ -320,6 +322,7 @@ func testMultipleConsumersNeedTrustedBHS( coordinator v22.CoordinatorV2_X, rwfe v22.RandomWordsFulfilled), ) { + ctx := testutils.Context(t) nConsumers := len(consumers) vrfKey := cltest.MustGenerateRandomKey(t) sendEth(t, ownerKey, uni.backend, vrfKey.Address, 10) @@ -364,7 +367,7 @@ func testMultipleConsumersNeedTrustedBHS( }) keys = append(keys, ownerKey, vrfKey) app := cltest.NewApplicationWithConfigV2AndKeyOnSimulatedBlockchain(t, config, uni.backend, keys...) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // Create VRF job. vrfJobs := createVRFJobs( @@ -403,7 +406,7 @@ func testMultipleConsumersNeedTrustedBHS( // Ensure log poller is ready and has all logs. chain := app.GetRelayers().LegacyEVMChains().Slice()[0] require.NoError(t, chain.LogPoller().Ready()) - require.NoError(t, chain.LogPoller().Replay(testutils.Context(t), 1)) + require.NoError(t, chain.LogPoller().Replay(ctx, 1)) for i := 0; i < nConsumers; i++ { consumer := consumers[i] @@ -445,7 +448,7 @@ func testMultipleConsumersNeedTrustedBHS( // Wait for fulfillment to be queued. gomega.NewGomegaWithT(t).Eventually(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("runs", len(runs)) return len(runs) == 1 @@ -534,6 +537,7 @@ func testSingleConsumerHappyPathBatchFulfillment( rwfe v22.RandomWordsFulfilled, subID *big.Int), ) { + ctx := testutils.Context(t) key1 := cltest.MustGenerateRandomKey(t) gasLanePriceWei := assets.GWei(10) config, db := heavyweight.FullTestDBV2(t, func(c *chainlink.Config, s *chainlink.Secrets) { @@ -555,7 +559,7 @@ func testSingleConsumerHappyPathBatchFulfillment( // Fund gas lane. sendEth(t, ownerKey, uni.backend, key1.Address, 10) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // Create VRF job using key1 and key2 on the same gas lane. jbs := createVRFJobs( @@ -590,7 +594,7 @@ func testSingleConsumerHappyPathBatchFulfillment( // Wait for fulfillment to be queued. gomega.NewGomegaWithT(t).Eventually(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("runs", len(runs)) if bigGasCallback { @@ -640,6 +644,7 @@ func testSingleConsumerNeedsTopUp( coordinator v22.CoordinatorV2_X, rwfe v22.RandomWordsFulfilled), ) { + ctx := testutils.Context(t) key := cltest.MustGenerateRandomKey(t) gasLanePriceWei := assets.GWei(1000) config, db := heavyweight.FullTestDBV2(t, func(c *chainlink.Config, s *chainlink.Secrets) { @@ -659,7 +664,7 @@ func testSingleConsumerNeedsTopUp( // Fund expensive gas lane. sendEth(t, ownerKey, uni.backend, key.Address, 10) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // Create VRF job. jbs := createVRFJobs( @@ -682,7 +687,7 @@ func testSingleConsumerNeedsTopUp( // Fulfillment will not be enqueued because subscriber doesn't have enough LINK. gomega.NewGomegaWithT(t).Consistently(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("assert 1", "runs", len(runs)) return len(runs) == 0 @@ -695,7 +700,7 @@ func testSingleConsumerNeedsTopUp( // Wait for fulfillment to go through. gomega.NewWithT(t).Eventually(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("assert 2", "runs", len(runs)) return len(runs) == 1 @@ -737,6 +742,7 @@ func testBlockHeaderFeeder( coordinator v22.CoordinatorV2_X, rwfe v22.RandomWordsFulfilled), ) { + ctx := testutils.Context(t) nConsumers := len(consumers) vrfKey := cltest.MustGenerateRandomKey(t) @@ -760,7 +766,7 @@ func testBlockHeaderFeeder( c.EVM[0].FinalityDepth = ptr[uint32](2) }) app := cltest.NewApplicationWithConfigV2AndKeyOnSimulatedBlockchain(t, config, uni.backend, ownerKey, vrfKey, bhfKey) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // Create VRF job. vrfJobs := createVRFJobs( @@ -792,7 +798,7 @@ func testBlockHeaderFeeder( // Ensure log poller is ready and has all logs. require.NoError(t, app.GetRelayers().LegacyEVMChains().Slice()[0].LogPoller().Ready()) - require.NoError(t, app.GetRelayers().LegacyEVMChains().Slice()[0].LogPoller().Replay(testutils.Context(t), 1)) + require.NoError(t, app.GetRelayers().LegacyEVMChains().Slice()[0].LogPoller().Replay(ctx, 1)) for i := 0; i < nConsumers; i++ { consumer := consumers[i] @@ -821,7 +827,7 @@ func testBlockHeaderFeeder( // Wait for fulfillment to be queued. gomega.NewGomegaWithT(t).Eventually(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("runs", len(runs)) return len(runs) == 1 @@ -900,6 +906,7 @@ func testSingleConsumerForcedFulfillment( batchEnabled bool, vrfVersion vrfcommon.Version, ) { + ctx := testutils.Context(t) key1 := cltest.MustGenerateRandomKey(t) key2 := cltest.MustGenerateRandomKey(t) gasLanePriceWei := assets.GWei(10) @@ -951,7 +958,7 @@ func testSingleConsumerForcedFulfillment( // Fund gas lanes. sendEth(t, ownerKey, uni.backend, key1.Address, 10) sendEth(t, ownerKey, uni.backend, key2.Address, 10) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // Create VRF job using key1 and key2 on the same gas lane. jbs := createVRFJobs( @@ -1065,6 +1072,7 @@ func testSingleConsumerEIP150( vrfVersion vrfcommon.Version, nativePayment bool, ) { + ctx := testutils.Context(t) callBackGasLimit := int64(2_500_000) // base callback gas. key1 := cltest.MustGenerateRandomKey(t) @@ -1090,7 +1098,7 @@ func testSingleConsumerEIP150( // Fund gas lane. sendEth(t, ownerKey, uni.backend, key1.Address, 10) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // Create VRF job. jbs := createVRFJobs( @@ -1114,7 +1122,7 @@ func testSingleConsumerEIP150( // Wait for simulation to pass. gomega.NewGomegaWithT(t).Eventually(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("runs", len(runs)) return len(runs) == 1 @@ -1132,6 +1140,7 @@ func testSingleConsumerEIP150Revert( vrfVersion vrfcommon.Version, nativePayment bool, ) { + ctx := testutils.Context(t) callBackGasLimit := int64(2_500_000) // base callback gas. eip150Fee := int64(0) // no premium given for callWithExactGas coordinatorFulfillmentOverhead := int64(90_000) // fixed gas used in coordinator fulfillment @@ -1160,7 +1169,7 @@ func testSingleConsumerEIP150Revert( // Fund gas lane. sendEth(t, ownerKey, uni.backend, key1.Address, 10) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // Create VRF job. jbs := createVRFJobs( @@ -1184,7 +1193,7 @@ func testSingleConsumerEIP150Revert( // Simulation should not pass. gomega.NewGomegaWithT(t).Consistently(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("runs", len(runs)) return len(runs) == 0 @@ -1202,6 +1211,7 @@ func testSingleConsumerBigGasCallbackSandwich( vrfVersion vrfcommon.Version, nativePayment bool, ) { + ctx := testutils.Context(t) key1 := cltest.MustGenerateRandomKey(t) gasLanePriceWei := assets.GWei(100) config, db := heavyweight.FullTestDBV2(t, func(c *chainlink.Config, s *chainlink.Secrets) { @@ -1224,7 +1234,7 @@ func testSingleConsumerBigGasCallbackSandwich( // Fund gas lane. sendEth(t, ownerKey, uni.backend, key1.Address, 10) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // Create VRF job. jbs := createVRFJobs( @@ -1253,7 +1263,7 @@ func testSingleConsumerBigGasCallbackSandwich( // Assert that we've completed 0 runs before adding 3 new requests. { - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) assert.Equal(t, 0, len(runs)) assert.Equal(t, 3, len(reqIDs)) @@ -1262,7 +1272,7 @@ func testSingleConsumerBigGasCallbackSandwich( // Wait for the 50_000 gas randomness request to be enqueued. gomega.NewGomegaWithT(t).Eventually(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("runs", len(runs)) return len(runs) == 1 @@ -1271,7 +1281,7 @@ func testSingleConsumerBigGasCallbackSandwich( // After the first successful request, no more will be enqueued. gomega.NewGomegaWithT(t).Consistently(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("assert 1", "runs", len(runs)) return len(runs) == 1 @@ -1285,7 +1295,7 @@ func testSingleConsumerBigGasCallbackSandwich( // Assert that we've still only completed 1 run before adding new requests. { - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) assert.Equal(t, 1, len(runs)) } @@ -1300,7 +1310,7 @@ func testSingleConsumerBigGasCallbackSandwich( // Fulfillment will not be enqueued because subscriber doesn't have enough LINK for any of the requests. gomega.NewGomegaWithT(t).Consistently(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("assert 1", "runs", len(runs)) return len(runs) == 1 @@ -1318,6 +1328,7 @@ func testSingleConsumerMultipleGasLanes( vrfVersion vrfcommon.Version, nativePayment bool, ) { + ctx := testutils.Context(t) cheapKey := cltest.MustGenerateRandomKey(t) expensiveKey := cltest.MustGenerateRandomKey(t) cheapGasLane := assets.GWei(10) @@ -1349,7 +1360,7 @@ func testSingleConsumerMultipleGasLanes( // Fund gas lanes. sendEth(t, ownerKey, uni.backend, cheapKey.Address, 10) sendEth(t, ownerKey, uni.backend, expensiveKey.Address, 10) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // Create VRF jobs. jbs := createVRFJobs( @@ -1374,7 +1385,7 @@ func testSingleConsumerMultipleGasLanes( // Wait for fulfillment to be queued for cheap key hash. gomega.NewGomegaWithT(t).Eventually(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("assert 1", "runs", len(runs)) return len(runs) == 1 @@ -1394,7 +1405,7 @@ func testSingleConsumerMultipleGasLanes( // We should not have any new fulfillments until a top up. gomega.NewWithT(t).Consistently(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("assert 2", "runs", len(runs)) return len(runs) == 1 @@ -1406,7 +1417,7 @@ func testSingleConsumerMultipleGasLanes( // Wait for fulfillment to be queued for expensive key hash. gomega.NewGomegaWithT(t).Eventually(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("assert 1", "runs", len(runs)) return len(runs) == 2 @@ -1442,6 +1453,7 @@ func testSingleConsumerAlwaysRevertingCallbackStillFulfilled( vrfVersion vrfcommon.Version, nativePayment bool, ) { + ctx := testutils.Context(t) key := cltest.MustGenerateRandomKey(t) gasLanePriceWei := assets.GWei(10) config, db := heavyweight.FullTestDBV2(t, func(c *chainlink.Config, s *chainlink.Secrets) { @@ -1464,7 +1476,7 @@ func testSingleConsumerAlwaysRevertingCallbackStillFulfilled( // Fund gas lane. sendEth(t, ownerKey, uni.backend, key.Address, 10) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // Create VRF job. jbs := createVRFJobs( @@ -1488,7 +1500,7 @@ func testSingleConsumerAlwaysRevertingCallbackStillFulfilled( // Wait for fulfillment to be queued. gomega.NewGomegaWithT(t).Eventually(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("runs", len(runs)) return len(runs) == 1 @@ -1511,6 +1523,7 @@ func testConsumerProxyHappyPath( vrfVersion vrfcommon.Version, nativePayment bool, ) { + ctx := testutils.Context(t) key1 := cltest.MustGenerateRandomKey(t) key2 := cltest.MustGenerateRandomKey(t) gasLanePriceWei := assets.GWei(10) @@ -1540,7 +1553,7 @@ func testConsumerProxyHappyPath( // Create gas lane. sendEth(t, ownerKey, uni.backend, key1.Address, 10) sendEth(t, ownerKey, uni.backend, key2.Address, 10) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // Create VRF job using key1 and key2 on the same gas lane. jbs := createVRFJobs( @@ -1565,7 +1578,7 @@ func testConsumerProxyHappyPath( // Wait for fulfillment to be queued. gomega.NewGomegaWithT(t).Eventually(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("runs", len(runs)) return len(runs) == 1 @@ -1591,7 +1604,7 @@ func testConsumerProxyHappyPath( t, consumerContract, consumerOwner, keyHash, subID, numWords, 750_000, uni.rootContract, uni.backend, nativePayment) gomega.NewGomegaWithT(t).Eventually(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("runs", len(runs)) return len(runs) == 2 @@ -1603,11 +1616,11 @@ func testConsumerProxyHappyPath( assertNumRandomWords(t, consumerContract, numWords) // Assert that both send addresses were used to fulfill the requests - n, err := uni.backend.PendingNonceAt(testutils.Context(t), key1.Address) + n, err := uni.backend.PendingNonceAt(ctx, key1.Address) require.NoError(t, err) require.EqualValues(t, 1, n) - n, err = uni.backend.PendingNonceAt(testutils.Context(t), key2.Address) + n, err = uni.backend.PendingNonceAt(ctx, key2.Address) require.NoError(t, err) require.EqualValues(t, 1, n) @@ -1644,6 +1657,7 @@ func testMaliciousConsumer( batchEnabled bool, vrfVersion vrfcommon.Version, ) { + ctx := testutils.Context(t) config, _ := heavyweight.FullTestDBV2(t, func(c *chainlink.Config, s *chainlink.Secrets) { c.EVM[0].GasEstimator.LimitDefault = ptr[uint64](2_000_000) c.EVM[0].GasEstimator.PriceMax = assets.GWei(1) @@ -1656,7 +1670,7 @@ func testMaliciousConsumer( carol := uni.vrfConsumers[0] app := cltest.NewApplicationWithConfigV2AndKeyOnSimulatedBlockchain(t, config, uni.backend, ownerKey) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) err := app.GetKeyStore().Unlock(cltest.Password) require.NoError(t, err) @@ -1702,7 +1716,7 @@ func testMaliciousConsumer( // by the node. var attempts []txmgr.TxAttempt gomega.NewWithT(t).Eventually(func() bool { - attempts, _, err = app.TxmStorageService().TxAttempts(testutils.Context(t), 0, 1000) + attempts, _, err = app.TxmStorageService().TxAttempts(ctx, 0, 1000) require.NoError(t, err) // It possible that we send the test request // before the job spawner has started the vrf services, which is fine @@ -1716,7 +1730,7 @@ func testMaliciousConsumer( // The fulfillment tx should succeed ch, err := app.GetRelayers().LegacyEVMChains().Get(evmtest.MustGetDefaultChainID(t, config.EVMConfigs()).String()) require.NoError(t, err) - r, err := ch.Client().TransactionReceipt(testutils.Context(t), attempts[0].Hash) + r, err := ch.Client().TransactionReceipt(ctx, attempts[0].Hash) require.NoError(t, err) require.Equal(t, uint64(1), r.Status) @@ -1759,6 +1773,7 @@ func testReplayOldRequestsOnStartUp( rwfe v22.RandomWordsFulfilled, subID *big.Int), ) { + ctx := testutils.Context(t) sendingKey := cltest.MustGenerateRandomKey(t) gasLanePriceWei := assets.GWei(10) config, _ := heavyweight.FullTestDBV2(t, func(c *chainlink.Config, s *chainlink.Secrets) { @@ -1778,7 +1793,7 @@ func testReplayOldRequestsOnStartUp( // Fund gas lanes. sendEth(t, ownerKey, uni.backend, sendingKey.Address, 10) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // Create VRF Key, register it to coordinator and export vrfkey, err := app.GetKeyStore().VRF().Create() @@ -1816,7 +1831,7 @@ func testReplayOldRequestsOnStartUp( // Start a new app and create VRF job using the same VRF key created above app = cltest.NewApplicationWithConfigV2AndKeyOnSimulatedBlockchain(t, config, uni.backend, ownerKey, sendingKey) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) vrfKey, err := app.GetKeyStore().VRF().Import(encodedVrfKey, testutils.Password) require.NoError(t, err) @@ -1863,7 +1878,7 @@ func testReplayOldRequestsOnStartUp( // Wait for fulfillment to be queued. gomega.NewGomegaWithT(t).Eventually(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("runs", len(runs)) return len(runs) == 1 diff --git a/core/services/vrf/v2/integration_v2_plus_test.go b/core/services/vrf/v2/integration_v2_plus_test.go index bfec76afec3..742ff99071c 100644 --- a/core/services/vrf/v2/integration_v2_plus_test.go +++ b/core/services/vrf/v2/integration_v2_plus_test.go @@ -1141,6 +1141,7 @@ func setupSubscriptionAndFund( func TestVRFV2PlusIntegration_Migration(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) ownerKey := cltest.MustGenerateRandomKey(t) uni := newVRFCoordinatorV2PlusUniverse(t, ownerKey, 1, false) key1 := cltest.MustGenerateRandomKey(t) @@ -1200,7 +1201,7 @@ func TestVRFV2PlusIntegration_Migration(t *testing.T) { // Wait for fulfillment to be queued. gomega.NewGomegaWithT(t).Eventually(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("runs", len(runs)) return len(runs) == 1 diff --git a/core/services/vrf/v2/integration_v2_test.go b/core/services/vrf/v2/integration_v2_test.go index 1a7c15a2508..0c81c3faca5 100644 --- a/core/services/vrf/v2/integration_v2_test.go +++ b/core/services/vrf/v2/integration_v2_test.go @@ -73,7 +73,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/keystore" "github.com/smartcontractkit/chainlink/v2/core/services/keystore/keys/ethkey" "github.com/smartcontractkit/chainlink/v2/core/services/keystore/keys/vrfkey" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" evmrelay "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm" "github.com/smartcontractkit/chainlink/v2/core/services/signatures/secp256k1" @@ -466,7 +465,8 @@ func deployOldCoordinator( // Send eth from prefunded account. // Amount is number of ETH not wei. func sendEth(t *testing.T, key ethkey.KeyV2, ec *backends.SimulatedBackend, to common.Address, eth int) { - nonce, err := ec.PendingNonceAt(testutils.Context(t), key.Address) + ctx := testutils.Context(t) + nonce, err := ec.PendingNonceAt(ctx, key.Address) require.NoError(t, err) tx := gethtypes.NewTx(&gethtypes.DynamicFeeTx{ ChainID: testutils.SimulatedChainID, @@ -480,7 +480,7 @@ func sendEth(t *testing.T, key ethkey.KeyV2, ec *backends.SimulatedBackend, to c }) signedTx, err := gethtypes.SignTx(tx, gethtypes.NewLondonSigner(testutils.SimulatedChainID), key.ToEcdsaPrivKey()) require.NoError(t, err) - err = ec.SendTransaction(testutils.Context(t), signedTx) + err = ec.SendTransaction(ctx, signedTx) require.NoError(t, err) ec.Commit() } @@ -996,7 +996,9 @@ func testEoa( batchingEnabled bool, batchCoordinatorAddress common.Address, vrfOwnerAddress *common.Address, - vrfVersion vrfcommon.Version) { + vrfVersion vrfcommon.Version, +) { + ctx := testutils.Context(t) gasLimit := int64(2_500_000) finalityDepth := uint32(50) @@ -1030,7 +1032,7 @@ func testEoa( // Fund gas lane. sendEth(t, ownerKey, uni.backend, key1.Address, 10) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // Create VRF job. jbs := createVRFJobs( @@ -1059,7 +1061,7 @@ func testEoa( // Ensure request is not fulfilled. gomega.NewGomegaWithT(t).Consistently(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("runs", len(runs)) return len(runs) == 0 @@ -1069,10 +1071,9 @@ func testEoa( var broadcastsBeforeFinality []evmlogger.LogBroadcast var broadcastsAfterFinality []evmlogger.LogBroadcast query := `SELECT block_hash, consumed, log_index, job_id FROM log_broadcasts` - q := pg.NewQ(app.GetSqlxDB(), app.Logger, app.Config.Database()) // Execute the query. - require.NoError(t, q.Select(&broadcastsBeforeFinality, query)) + require.NoError(t, app.GetDB().SelectContext(ctx, &broadcastsBeforeFinality, query)) // Ensure there is only one log broadcast (our EOA request), and that // it hasn't been marked as consumed yet. @@ -1087,14 +1088,14 @@ func testEoa( // Ensure the request is still not fulfilled. gomega.NewGomegaWithT(t).Consistently(func() bool { uni.backend.Commit() - runs, err := app.PipelineORM().GetAllRuns() + runs, err := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) t.Log("runs", len(runs)) return len(runs) == 0 }, 5*time.Second, time.Second).Should(gomega.BeTrue()) // Execute the query for log broadcasts again after finality depth has elapsed. - require.NoError(t, q.Select(&broadcastsAfterFinality, query)) + require.NoError(t, app.GetDB().SelectContext(ctx, &broadcastsAfterFinality, query)) // Ensure that there is still only one log broadcast (our EOA request), but that // it has been marked as "consumed," such that it won't be retried. @@ -1158,6 +1159,7 @@ func deployWrapper(t *testing.T, uni coordinatorV2UniverseCommon, wrapperOverhea func TestVRFV2Integration_SingleConsumer_Wrapper(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) wrapperOverhead := uint32(30_000) coordinatorOverhead := uint32(90_000) @@ -1179,7 +1181,7 @@ func TestVRFV2Integration_SingleConsumer_Wrapper(t *testing.T) { // Fund gas lane. sendEth(t, ownerKey, uni.backend, key1.Address, 10) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // Create VRF job. jbs := createVRFJobs( @@ -1221,7 +1223,7 @@ func TestVRFV2Integration_SingleConsumer_Wrapper(t *testing.T) { // Wait for simulation to pass. gomega.NewGomegaWithT(t).Eventually(func() bool { uni.backend.Commit() - runs, err2 := app.PipelineORM().GetAllRuns() + runs, err2 := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err2) t.Log("runs", len(runs)) return len(runs) == 1 @@ -1238,6 +1240,7 @@ func TestVRFV2Integration_SingleConsumer_Wrapper(t *testing.T) { func TestVRFV2Integration_Wrapper_High_Gas(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) wrapperOverhead := uint32(30_000) coordinatorOverhead := uint32(90_000) @@ -1261,7 +1264,7 @@ func TestVRFV2Integration_Wrapper_High_Gas(t *testing.T) { // Fund gas lane. sendEth(t, ownerKey, uni.backend, key1.Address, 10) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) // Create VRF job. jbs := createVRFJobs( @@ -1303,7 +1306,7 @@ func TestVRFV2Integration_Wrapper_High_Gas(t *testing.T) { // Wait for simulation to pass. gomega.NewGomegaWithT(t).Eventually(func() bool { uni.backend.Commit() - runs, err2 := app.PipelineORM().GetAllRuns() + runs, err2 := app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err2) t.Log("runs", len(runs)) return len(runs) == 1 @@ -1631,6 +1634,7 @@ func TestSimpleConsumerExample(t *testing.T) { func TestIntegrationVRFV2(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) // Reconfigure the sim chain with a default gas price of 1 gwei, // max gas limit of 2M and a key specific max 10 gwei price. // Keep the prices low so we can operate with small link balance subscriptions. @@ -1650,11 +1654,11 @@ func TestIntegrationVRFV2(t *testing.T) { carolContractAddress := uni.consumerContractAddresses[0] app := cltest.NewApplicationWithConfigV2AndKeyOnSimulatedBlockchain(t, config, uni.backend, key) - keys, err := app.KeyStore.Eth().EnabledKeysForChain(testutils.Context(t), testutils.SimulatedChainID) + keys, err := app.KeyStore.Eth().EnabledKeysForChain(ctx, testutils.SimulatedChainID) require.NoError(t, err) require.Zero(t, key.Cmp(keys[0])) - require.NoError(t, app.Start(testutils.Context(t))) + require.NoError(t, app.Start(ctx)) var chain legacyevm.Chain chain, err = app.GetRelayers().LegacyEVMChains().Get(testutils.SimulatedChainID.String()) require.NoError(t, err) @@ -1723,7 +1727,7 @@ func TestIntegrationVRFV2(t *testing.T) { // by the node. var runs []pipeline.Run gomega.NewWithT(t).Eventually(func() bool { - runs, err = app.PipelineORM().GetAllRuns() + runs, err = app.PipelineORM().GetAllRuns(ctx) require.NoError(t, err) // It is possible that we send the test request // before the job spawner has started the vrf services, which is fine @@ -1745,7 +1749,7 @@ func TestIntegrationVRFV2(t *testing.T) { return len(rf) == 1 }, testutils.WaitTimeout(t), 500*time.Millisecond).Should(gomega.BeTrue()) assert.True(t, rf[0].Success(), "expected callback to succeed") - fulfillReceipt, err := uni.backend.TransactionReceipt(testutils.Context(t), rf[0].Raw().TxHash) + fulfillReceipt, err := uni.backend.TransactionReceipt(ctx, rf[0].Raw().TxHash) require.NoError(t, err) // Assert all the random words received by the consumer are different and non-zero. @@ -1813,7 +1817,7 @@ func TestIntegrationVRFV2(t *testing.T) { // We should see the response count present require.NoError(t, err) var counts map[string]uint64 - counts, err = listenerV2.GetStartingResponseCountsV2(testutils.Context(t)) + counts, err = listenerV2.GetStartingResponseCountsV2(ctx) require.NoError(t, err) t.Log(counts, rf[0].RequestID().String()) assert.Equal(t, uint64(1), counts[rf[0].RequestID().String()]) diff --git a/core/services/vrf/v2/listener_v2.go b/core/services/vrf/v2/listener_v2.go index 71c6e72a06f..e820cff63b7 100644 --- a/core/services/vrf/v2/listener_v2.go +++ b/core/services/vrf/v2/listener_v2.go @@ -14,6 +14,7 @@ import ( "github.com/theodesp/go-heaps/pairing" "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" txmgrcommon "github.com/smartcontractkit/chainlink/v2/common/txmgr" txmgrtypes "github.com/smartcontractkit/chainlink/v2/common/txmgr/types" @@ -29,7 +30,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/job" "github.com/smartcontractkit/chainlink/v2/core/services/keystore" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" "github.com/smartcontractkit/chainlink/v2/core/services/vrf/vrfcommon" ) @@ -70,7 +70,7 @@ func New( l logger.Logger, chain legacyevm.Chain, chainID *big.Int, - q pg.Q, + ds sqlutil.DataSource, coordinator CoordinatorV2_X, batchCoordinator batch_vrf_coordinator_v2.BatchVRFCoordinatorV2Interface, vrfOwner vrf_owner.VRFOwnerInterface, @@ -93,7 +93,7 @@ func New( vrfOwner: vrfOwner, pipelineRunner: pipelineRunner, job: job, - q: q, + ds: ds, gethks: gethks, chStop: make(chan struct{}), reqAdded: reqAdded, @@ -120,7 +120,7 @@ type listenerV2 struct { pipelineRunner pipeline.Runner job job.Job - q pg.Q + ds sqlutil.DataSource gethks keystore.Eth chStop services.StopChan diff --git a/core/services/vrf/v2/listener_v2_log_processor.go b/core/services/vrf/v2/listener_v2_log_processor.go index db84fb47e3e..673f8618c0b 100644 --- a/core/services/vrf/v2/listener_v2_log_processor.go +++ b/core/services/vrf/v2/listener_v2_log_processor.go @@ -20,6 +20,7 @@ import ( "github.com/pkg/errors" "go.uber.org/multierr" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/utils/hex" txmgrcommon "github.com/smartcontractkit/chainlink/v2/common/txmgr" txmgrtypes "github.com/smartcontractkit/chainlink/v2/common/txmgr/types" @@ -28,7 +29,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/generated/vrf_coordinator_v2" "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/generated/vrf_coordinator_v2plus_interface" "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" "github.com/smartcontractkit/chainlink/v2/core/services/vrf/vrfcommon" "github.com/smartcontractkit/chainlink/v2/core/utils" @@ -565,55 +565,53 @@ func (lsn *listenerV2) enqueueForceFulfillment( } // fulfill the request through the VRF owner - err = lsn.q.Transaction(func(tx pg.Queryer) error { - lsn.l.Infow("VRFOwner.fulfillRandomWords vs. VRFCoordinatorV2.fulfillRandomWords", - "vrf_owner.fulfillRandomWords", hexutil.Encode(vrfOwnerABI.Methods["fulfillRandomWords"].ID), - "vrf_coordinator_v2.fulfillRandomWords", hexutil.Encode(coordinatorV2ABI.Methods["fulfillRandomWords"].ID), - ) + lsn.l.Infow("VRFOwner.fulfillRandomWords vs. VRFCoordinatorV2.fulfillRandomWords", + "vrf_owner.fulfillRandomWords", hexutil.Encode(vrfOwnerABI.Methods["fulfillRandomWords"].ID), + "vrf_coordinator_v2.fulfillRandomWords", hexutil.Encode(coordinatorV2ABI.Methods["fulfillRandomWords"].ID), + ) - vrfOwnerAddress1 := lsn.vrfOwner.Address() - vrfOwnerAddressSpec := lsn.job.VRFSpec.VRFOwnerAddress.Address() - lsn.l.Infow("addresses diff", "wrapper_address", vrfOwnerAddress1, "spec_address", vrfOwnerAddressSpec) + vrfOwnerAddress1 := lsn.vrfOwner.Address() + vrfOwnerAddressSpec := lsn.job.VRFSpec.VRFOwnerAddress.Address() + lsn.l.Infow("addresses diff", "wrapper_address", vrfOwnerAddress1, "spec_address", vrfOwnerAddressSpec) - lsn.l.Infow("fulfillRandomWords payload", "proof", p.proof, "commitment", p.reqCommitment.Get(), "payload", p.payload) - txData := hexutil.MustDecode(p.payload) - if err != nil { - return fmt.Errorf("abi pack VRFOwner.fulfillRandomWords: %w", err) - } - estimateGasLimit, err := lsn.chain.Client().EstimateGas(ctx, ethereum.CallMsg{ - From: fromAddress, - To: &vrfOwnerAddressSpec, - Data: txData, - }) - if err != nil { - return fmt.Errorf("failed to estimate gas on VRFOwner.fulfillRandomWords: %w", err) - } + lsn.l.Infow("fulfillRandomWords payload", "proof", p.proof, "commitment", p.reqCommitment.Get(), "payload", p.payload) + txData := hexutil.MustDecode(p.payload) + if err != nil { + err = fmt.Errorf("abi pack VRFOwner.fulfillRandomWords: %w", err) + return + } + estimateGasLimit, err := lsn.chain.Client().EstimateGas(ctx, ethereum.CallMsg{ + From: fromAddress, + To: &vrfOwnerAddressSpec, + Data: txData, + }) + if err != nil { + err = fmt.Errorf("failed to estimate gas on VRFOwner.fulfillRandomWords: %w", err) + return + } - lsn.l.Infow("Estimated gas limit on force fulfillment", - "estimateGasLimit", estimateGasLimit, "pipelineGasLimit", p.gasLimit) - if estimateGasLimit < p.gasLimit { - estimateGasLimit = p.gasLimit - } + lsn.l.Infow("Estimated gas limit on force fulfillment", + "estimateGasLimit", estimateGasLimit, "pipelineGasLimit", p.gasLimit) + if estimateGasLimit < p.gasLimit { + estimateGasLimit = p.gasLimit + } - requestID := common.BytesToHash(p.req.req.RequestID().Bytes()) - subID := p.req.req.SubID() - requestTxHash := p.req.req.Raw().TxHash - etx, err = lsn.chain.TxManager().CreateTransaction(ctx, txmgr.TxRequest{ - FromAddress: fromAddress, - ToAddress: lsn.vrfOwner.Address(), - EncodedPayload: txData, - FeeLimit: estimateGasLimit, - Strategy: txmgrcommon.NewSendEveryStrategy(), - Meta: &txmgr.TxMeta{ - RequestID: &requestID, - SubID: ptr(subID.Uint64()), - RequestTxHash: &requestTxHash, - // No max link since simulation failed - }, - }) - return err + requestID := common.BytesToHash(p.req.req.RequestID().Bytes()) + subID := p.req.req.SubID() + requestTxHash := p.req.req.Raw().TxHash + return lsn.chain.TxManager().CreateTransaction(ctx, txmgr.TxRequest{ + FromAddress: fromAddress, + ToAddress: lsn.vrfOwner.Address(), + EncodedPayload: txData, + FeeLimit: estimateGasLimit, + Strategy: txmgrcommon.NewSendEveryStrategy(), + Meta: &txmgr.TxMeta{ + RequestID: &requestID, + SubID: ptr(subID.Uint64()), + RequestTxHash: &requestTxHash, + // No max link since simulation failed + }, }) - return } // For an errored pipeline run, wait until the finality depth of the chain to have elapsed, @@ -786,8 +784,8 @@ func (lsn *listenerV2) processRequestsPerSubHelper( ll.Infow("Enqueuing fulfillment") var transaction txmgr.Tx - err = lsn.q.Transaction(func(tx pg.Queryer) error { - if err = lsn.pipelineRunner.InsertFinishedRun(p.run, true, pg.WithQueryer(tx)); err != nil { + err = sqlutil.TransactDataSource(ctx, lsn.ds, nil, func(tx sqlutil.DataSource) error { + if err = lsn.pipelineRunner.InsertFinishedRun(ctx, tx, p.run, true); err != nil { return err } diff --git a/core/services/vrf/v2/listener_v2_types.go b/core/services/vrf/v2/listener_v2_types.go index f10297f31a9..c7dc45bb3bd 100644 --- a/core/services/vrf/v2/listener_v2_types.go +++ b/core/services/vrf/v2/listener_v2_types.go @@ -8,10 +8,10 @@ import ( "github.com/ethereum/go-ethereum/common" heaps "github.com/theodesp/go-heaps" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" txmgrcommon "github.com/smartcontractkit/chainlink/v2/common/txmgr" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/txmgr" "github.com/smartcontractkit/chainlink/v2/core/logger" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" "github.com/smartcontractkit/chainlink/v2/core/services/vrf/vrfcommon" ) @@ -222,8 +222,8 @@ func (lsn *listenerV2) processBatch( ) ll.Info("Enqueuing batch fulfillment") var ethTX txmgr.Tx - err = lsn.q.Transaction(func(tx pg.Queryer) error { - if err = lsn.pipelineRunner.InsertFinishedRuns(batch.runs, true, pg.WithQueryer(tx)); err != nil { + err = sqlutil.TransactDataSource(ctx, lsn.ds, nil, func(tx sqlutil.DataSource) error { + if err = lsn.pipelineRunner.InsertFinishedRuns(ctx, tx, batch.runs, true); err != nil { return fmt.Errorf("inserting finished pipeline runs: %w", err) } diff --git a/core/services/vrf/v2/reverted_txns.go b/core/services/vrf/v2/reverted_txns.go index d2f62fbf271..cfd9954a208 100644 --- a/core/services/vrf/v2/reverted_txns.go +++ b/core/services/vrf/v2/reverted_txns.go @@ -17,13 +17,13 @@ import ( "github.com/ethereum/go-ethereum/crypto" "github.com/pkg/errors" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" txmgrcommon "github.com/smartcontractkit/chainlink/v2/common/txmgr" evmclient "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/txmgr" evmtypes "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" evmutils "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils" "github.com/smartcontractkit/chainlink/v2/core/gethwrappers/generated/vrf_coordinator_v2" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/utils" ) @@ -71,15 +71,15 @@ func (lsn *listenerV2) handleRevertedTxns(ctx context.Context, pollPeriod time.D lsn.l.Infow("Handling reverted txns") // Fetch recent single and batch txns, that have not been force-fulfilled - recentSingleTxns, err := lsn.fetchRecentSingleTxns(ctx, lsn.q, lsn.chainID.Uint64(), pollPeriod) + recentSingleTxns, err := lsn.fetchRecentSingleTxns(ctx, lsn.ds, lsn.chainID.Uint64(), pollPeriod) if err != nil { lsn.l.Fatalw("Fetch recent txns", "err", err) } - recentBatchTxns, err := lsn.fetchRecentBatchTxns(ctx, lsn.q, lsn.chainID.Uint64(), pollPeriod) + recentBatchTxns, err := lsn.fetchRecentBatchTxns(ctx, lsn.ds, lsn.chainID.Uint64(), pollPeriod) if err != nil { lsn.l.Fatalw("Fetch recent batch txns", "err", err) } - recentForceFulfillmentTxns, err := lsn.fetchRevertedForceFulfilmentTxns(ctx, lsn.q, lsn.chainID.Uint64(), pollPeriod) + recentForceFulfillmentTxns, err := lsn.fetchRevertedForceFulfilmentTxns(ctx, lsn.ds, lsn.chainID.Uint64(), pollPeriod) if err != nil { lsn.l.Fatalw("Fetch recent reverted force-fulfillment txns", "err", err) } @@ -108,7 +108,7 @@ func (lsn *listenerV2) handleRevertedTxns(ctx context.Context, pollPeriod time.D } func (lsn *listenerV2) fetchRecentSingleTxns(ctx context.Context, - q pg.Q, + ds sqlutil.DataSource, chainID uint64, pollPeriod time.Duration) ([]TxnReceiptDB, error) { @@ -155,7 +155,7 @@ func (lsn *listenerV2) fetchRecentSingleTxns(ctx context.Context, var recentReceipts []TxnReceiptDB before := time.Now() - err := q.Select(&recentReceipts, sqlQuery, chainID) + err := ds.SelectContext(ctx, &recentReceipts, sqlQuery, chainID) lsn.postSqlLog(ctx, before, pollPeriod, "FetchRecentSingleTxns") if err != nil && !errors.Is(err, sql.ErrNoRows) { return nil, errors.Wrap(err, "Error fetching recent non-force-fulfilled txns") @@ -172,7 +172,7 @@ func (lsn *listenerV2) fetchRecentSingleTxns(ctx context.Context, } func (lsn *listenerV2) fetchRecentBatchTxns(ctx context.Context, - q pg.Q, + ds sqlutil.DataSource, chainID uint64, pollPeriod time.Duration) ([]TxnReceiptDB, error) { sqlQuery := fmt.Sprintf(` @@ -217,7 +217,7 @@ func (lsn *listenerV2) fetchRecentBatchTxns(ctx context.Context, var recentReceipts []TxnReceiptDB before := time.Now() - err := q.Select(&recentReceipts, sqlQuery, chainID) + err := ds.SelectContext(ctx, &recentReceipts, sqlQuery, chainID) lsn.postSqlLog(ctx, before, pollPeriod, "FetchRecentBatchTxns") if err != nil && !errors.Is(err, sql.ErrNoRows) { return nil, errors.Wrap(err, "Error fetching recent non-force-fulfilled txns") @@ -231,7 +231,7 @@ func (lsn *listenerV2) fetchRecentBatchTxns(ctx context.Context, } func (lsn *listenerV2) fetchRevertedForceFulfilmentTxns(ctx context.Context, - q pg.Q, + ds sqlutil.DataSource, chainID uint64, pollPeriod time.Duration) ([]TxnReceiptDB, error) { @@ -271,7 +271,7 @@ func (lsn *listenerV2) fetchRevertedForceFulfilmentTxns(ctx context.Context, var recentReceipts []TxnReceiptDB before := time.Now() - err := q.Select(&recentReceipts, sqlQuery, chainID) + err := ds.SelectContext(ctx, &recentReceipts, sqlQuery, chainID) lsn.postSqlLog(ctx, before, pollPeriod, "FetchRevertedForceFulfilmentTxns") if err != nil && !errors.Is(err, sql.ErrNoRows) { return nil, errors.Wrap(err, "Error fetching recent reverted force-fulfilled txns") @@ -300,7 +300,7 @@ func (lsn *listenerV2) fetchRevertedForceFulfilmentTxns(ctx context.Context, `, ReqScanTimeRangeInDB) var allReceipts []TxnReceiptDB before = time.Now() - err = q.Select(&allReceipts, sqlQueryAll, chainID) + err = ds.SelectContext(ctx, &allReceipts, sqlQueryAll, chainID) lsn.postSqlLog(ctx, before, pollPeriod, "Fetch all ForceFulfilment Txns") if err != nil && !errors.Is(err, sql.ErrNoRows) { return nil, errors.Wrap(err, "Error fetching all recent force-fulfilled txns") @@ -389,9 +389,10 @@ func (lsn *listenerV2) postSqlLog(ctx context.Context, begin time.Time, pollPeri lsn.l.Debugw("SQL context canceled", "ms", elapsed.Milliseconds(), "err", ctx.Err(), "sql", queryName) } - timeout := lsn.q.QueryTimeout - if timeout <= 0 { - timeout = pollPeriod + timeout := pollPeriod + deadline, ok := ctx.Deadline() + if ok { + timeout = deadline.Sub(begin) } pct := float64(elapsed) / float64(timeout) diff --git a/core/services/webhook/delegate.go b/core/services/webhook/delegate.go index 0c08e992f32..690ae38d088 100644 --- a/core/services/webhook/delegate.go +++ b/core/services/webhook/delegate.go @@ -13,7 +13,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/job" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" "github.com/smartcontractkit/chainlink/v2/core/services/pipeline" ) @@ -74,7 +73,7 @@ func (d *Delegate) BeforeJobDeleted(spec job.Job) { ) } } -func (d *Delegate) OnDeleteJob(ctx context.Context, jb job.Job, q pg.Queryer) error { return nil } +func (d *Delegate) OnDeleteJob(context.Context, job.Job) error { return nil } // ServicesForSpec satisfies the job.Delegate interface. func (d *Delegate) ServicesForSpec(ctx context.Context, spec job.Job) ([]job.ServiceCtx, error) { diff --git a/core/services/workflows/delegate.go b/core/services/workflows/delegate.go index 6db39d52dd6..dedf53e369b 100644 --- a/core/services/workflows/delegate.go +++ b/core/services/workflows/delegate.go @@ -15,7 +15,6 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/chains/legacyevm" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/job" - "github.com/smartcontractkit/chainlink/v2/core/services/pg" ) type Delegate struct { @@ -36,7 +35,7 @@ func (d *Delegate) AfterJobCreated(jb job.Job) {} func (d *Delegate) BeforeJobDeleted(spec job.Job) {} -func (d *Delegate) OnDeleteJob(ctx context.Context, jb job.Job, q pg.Queryer) error { return nil } +func (d *Delegate) OnDeleteJob(context.Context, job.Job) error { return nil } // ServicesForSpec satisfies the job.Delegate interface. func (d *Delegate) ServicesForSpec(ctx context.Context, spec job.Job) ([]job.ServiceCtx, error) { diff --git a/core/store/migrate/migrate_test.go b/core/store/migrate/migrate_test.go index 286e1b3a295..b3a15123efa 100644 --- a/core/store/migrate/migrate_test.go +++ b/core/store/migrate/migrate_test.go @@ -78,14 +78,15 @@ func TestMigrate_0100_BootstrapConfigs(t *testing.T) { err := goose.UpTo(db.DB, migrationDir, 99) require.NoError(t, err) - pipelineORM := pipeline.NewORM(db, lggr, cfg.Database(), cfg.JobPipeline().MaxSuccessfulRuns()) - pipelineID, err := pipelineORM.CreateSpec(pipeline.Pipeline{}, 0) + pipelineORM := pipeline.NewORM(db, lggr, cfg.JobPipeline().MaxSuccessfulRuns()) + ctx := testutils.Context(t) + pipelineID, err := pipelineORM.CreateSpec(ctx, nil, pipeline.Pipeline{}, 0) require.NoError(t, err) - pipelineID2, err := pipelineORM.CreateSpec(pipeline.Pipeline{}, 0) + pipelineID2, err := pipelineORM.CreateSpec(ctx, nil, pipeline.Pipeline{}, 0) require.NoError(t, err) - nonBootstrapPipelineID, err := pipelineORM.CreateSpec(pipeline.Pipeline{}, 0) + nonBootstrapPipelineID, err := pipelineORM.CreateSpec(ctx, nil, pipeline.Pipeline{}, 0) require.NoError(t, err) - newFormatBoostrapPipelineID2, err := pipelineORM.CreateSpec(pipeline.Pipeline{}, 0) + newFormatBoostrapPipelineID2, err := pipelineORM.CreateSpec(ctx, nil, pipeline.Pipeline{}, 0) require.NoError(t, err) // OCR2 struct at migration v0099 diff --git a/core/web/pipeline_runs_controller.go b/core/web/pipeline_runs_controller.go index 2c6caa648fc..1bd52b021c3 100644 --- a/core/web/pipeline_runs_controller.go +++ b/core/web/pipeline_runs_controller.go @@ -66,6 +66,7 @@ func (prc *PipelineRunsController) Index(c *gin.Context, size, page, offset int) // Example: // "GET /jobs/:ID/runs/:runID" func (prc *PipelineRunsController) Show(c *gin.Context) { + ctx := c.Request.Context() pipelineRun := pipeline.Run{} err := pipelineRun.SetID(c.Param("runID")) if err != nil { @@ -73,7 +74,7 @@ func (prc *PipelineRunsController) Show(c *gin.Context) { return } - pipelineRun, err = prc.App.PipelineORM().FindRun(pipelineRun.ID) + pipelineRun, err = prc.App.PipelineORM().FindRun(ctx, pipelineRun.ID) if err != nil { jsonAPIError(c, http.StatusInternalServerError, err) return @@ -87,8 +88,9 @@ func (prc *PipelineRunsController) Show(c *gin.Context) { // Example: // "POST /jobs/:ID/runs" func (prc *PipelineRunsController) Create(c *gin.Context) { + ctx := c.Request.Context() respondWithPipelineRun := func(jobRunID int64) { - pipelineRun, err := prc.App.PipelineORM().FindRun(jobRunID) + pipelineRun, err := prc.App.PipelineORM().FindRun(ctx, jobRunID) if err != nil { jsonAPIError(c, http.StatusInternalServerError, err) return diff --git a/core/web/resolver/job_run_test.go b/core/web/resolver/job_run_test.go index 18036311155..a35a2f66ac5 100644 --- a/core/web/resolver/job_run_test.go +++ b/core/web/resolver/job_run_test.go @@ -286,7 +286,7 @@ func TestResolver_RunJob(t *testing.T) { authenticated: true, before: func(f *gqlTestFramework) { f.App.On("RunJobV2", mock.Anything, id, (map[string]interface{})(nil)).Return(int64(25), nil) - f.Mocks.pipelineORM.On("FindRun", int64(25)).Return(pipeline.Run{ + f.Mocks.pipelineORM.On("FindRun", mock.Anything, int64(25)).Return(pipeline.Run{ ID: 2, PipelineSpecID: 5, CreatedAt: f.Timestamp(), @@ -377,7 +377,7 @@ func TestResolver_RunJob(t *testing.T) { authenticated: true, before: func(f *gqlTestFramework) { f.App.On("RunJobV2", mock.Anything, id, (map[string]interface{})(nil)).Return(int64(25), nil) - f.Mocks.pipelineORM.On("FindRun", int64(25)).Return(pipeline.Run{}, gError) + f.Mocks.pipelineORM.On("FindRun", mock.Anything, int64(25)).Return(pipeline.Run{}, gError) f.App.On("PipelineORM").Return(f.Mocks.pipelineORM) }, query: mutation, diff --git a/core/web/resolver/mutation.go b/core/web/resolver/mutation.go index 85f3407169e..551b8d8e89a 100644 --- a/core/web/resolver/mutation.go +++ b/core/web/resolver/mutation.go @@ -1162,7 +1162,7 @@ func (r *Resolver) RunJob(ctx context.Context, args struct { return nil, err } - plnRun, err := r.App.PipelineORM().FindRun(jobRunID) + plnRun, err := r.App.PipelineORM().FindRun(ctx, jobRunID) if err != nil { return nil, err } From e77c458e537789be64f1af298150df5304ee789a Mon Sep 17 00:00:00 2001 From: Bartek Tofel Date: Wed, 17 Apr 2024 18:05:43 +0200 Subject: [PATCH 10/19] [TT-1078] run id removal fix (#12838) * use latest Seth * pass run_id from logging config * fail test on purpose * remove on demand failure * use tagged CTF version --- integration-tests/docker/test_env/test_env.go | 3 ++- integration-tests/docker/test_env/test_env_builder.go | 2 ++ integration-tests/go.mod | 2 +- integration-tests/go.sum | 4 ++-- integration-tests/load/go.mod | 2 +- integration-tests/load/go.sum | 4 ++-- 6 files changed, 10 insertions(+), 7 deletions(-) diff --git a/integration-tests/docker/test_env/test_env.go b/integration-tests/docker/test_env/test_env.go index cbcb943e695..d54f51eeea1 100644 --- a/integration-tests/docker/test_env/test_env.go +++ b/integration-tests/docker/test_env/test_env.go @@ -34,6 +34,7 @@ type CLClusterTestEnv struct { Cfg *TestEnvConfig DockerNetwork *tc.DockerNetwork LogStream *logstream.LogStream + TestConfig core_testconfig.GlobalTestConfig /* components */ ClCluster *ClCluster @@ -201,7 +202,7 @@ func (te *CLClusterTestEnv) Terminate() error { func (te *CLClusterTestEnv) Cleanup() error { te.l.Info().Msg("Cleaning up test environment") - runIdErr := runid.RemoveLocalRunId() + runIdErr := runid.RemoveLocalRunId(te.TestConfig.GetLoggingConfig().RunId) if runIdErr != nil { te.l.Warn().Msgf("Failed to remove .run.id file due to: %s (not a big deal, you can still remove it manually)", runIdErr.Error()) } diff --git a/integration-tests/docker/test_env/test_env_builder.go b/integration-tests/docker/test_env/test_env_builder.go index e057df03806..b9406bf16aa 100644 --- a/integration-tests/docker/test_env/test_env_builder.go +++ b/integration-tests/docker/test_env/test_env_builder.go @@ -224,6 +224,8 @@ func (b *CLTestEnvBuilder) Build() (*CLClusterTestEnv, error) { } } + b.te.TestConfig = b.testConfig + var err error if b.t != nil { b.te.WithTestInstance(b.t) diff --git a/integration-tests/go.mod b/integration-tests/go.mod index b3f0f7a90c6..2042495382a 100644 --- a/integration-tests/go.mod +++ b/integration-tests/go.mod @@ -25,7 +25,7 @@ require ( github.com/slack-go/slack v0.12.2 github.com/smartcontractkit/chainlink-automation v1.0.3 github.com/smartcontractkit/chainlink-common v0.1.7-0.20240415164156-8872a8f311cb - github.com/smartcontractkit/chainlink-testing-framework v1.28.2 + github.com/smartcontractkit/chainlink-testing-framework v1.28.3 github.com/smartcontractkit/chainlink-vrf v0.0.0-20240222010609-cd67d123c772 github.com/smartcontractkit/chainlink/v2 v2.0.0-00010101000000-000000000000 github.com/smartcontractkit/libocr v0.0.0-20240326191951-2bbe9382d052 diff --git a/integration-tests/go.sum b/integration-tests/go.sum index 957e3988143..c054bcc21b0 100644 --- a/integration-tests/go.sum +++ b/integration-tests/go.sum @@ -1533,8 +1533,8 @@ github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240216142700-c5869534c19 github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240216142700-c5869534c19e/go.mod h1:JiykN+8W5TA4UD2ClrzQCVvcH3NcyLEVv7RwY0busrw= github.com/smartcontractkit/chainlink-starknet/relayer v0.0.1-beta-test.0.20240325075535-0f7eb05ee595 h1:y6ks0HsSOhPUueOmTcoxDQ50RCS1XINlRDTemZyHjFw= github.com/smartcontractkit/chainlink-starknet/relayer v0.0.1-beta-test.0.20240325075535-0f7eb05ee595/go.mod h1:vV6WfnVIbK5Q1JsIru4YcTG0T1uRpLJm6t2BgCnCSsg= -github.com/smartcontractkit/chainlink-testing-framework v1.28.2 h1:H4/RW9J3EmHi4uJUxREJHkxxBbKRRk/eO3YhuR9a9zI= -github.com/smartcontractkit/chainlink-testing-framework v1.28.2/go.mod h1:jN+HgXbriq6fKRlIqLw9F3I81aYImV6kBJkIfz0mdIA= +github.com/smartcontractkit/chainlink-testing-framework v1.28.3 h1:rZ622PUSE9jJvI2g1SNNcMJedXyMzq9XJ8SbV2j9TvA= +github.com/smartcontractkit/chainlink-testing-framework v1.28.3/go.mod h1:jN+HgXbriq6fKRlIqLw9F3I81aYImV6kBJkIfz0mdIA= github.com/smartcontractkit/chainlink-vrf v0.0.0-20240222010609-cd67d123c772 h1:LQmRsrzzaYYN3wEU1l5tWiccznhvbyGnu2N+wHSXZAo= github.com/smartcontractkit/chainlink-vrf v0.0.0-20240222010609-cd67d123c772/go.mod h1:Kn1Hape05UzFZ7bOUnm3GVsHzP0TNrVmpfXYNHdqGGs= github.com/smartcontractkit/go-plugin v0.0.0-20240208201424-b3b91517de16 h1:TFe+FvzxClblt6qRfqEhUfa4kFQx5UobuoFGO2W4mMo= diff --git a/integration-tests/load/go.mod b/integration-tests/load/go.mod index cbde975afcc..882a998feec 100644 --- a/integration-tests/load/go.mod +++ b/integration-tests/load/go.mod @@ -17,7 +17,7 @@ require ( github.com/slack-go/slack v0.12.2 github.com/smartcontractkit/chainlink-automation v1.0.3 github.com/smartcontractkit/chainlink-common v0.1.7-0.20240415164156-8872a8f311cb - github.com/smartcontractkit/chainlink-testing-framework v1.28.2 + github.com/smartcontractkit/chainlink-testing-framework v1.28.3 github.com/smartcontractkit/chainlink/integration-tests v0.0.0-20240214231432-4ad5eb95178c github.com/smartcontractkit/chainlink/v2 v2.9.0-beta0.0.20240216210048-da02459ddad8 github.com/smartcontractkit/libocr v0.0.0-20240326191951-2bbe9382d052 diff --git a/integration-tests/load/go.sum b/integration-tests/load/go.sum index 2ae24681388..d36ce67be0c 100644 --- a/integration-tests/load/go.sum +++ b/integration-tests/load/go.sum @@ -1516,8 +1516,8 @@ github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240216142700-c5869534c19 github.com/smartcontractkit/chainlink-solana v1.0.3-0.20240216142700-c5869534c19e/go.mod h1:JiykN+8W5TA4UD2ClrzQCVvcH3NcyLEVv7RwY0busrw= github.com/smartcontractkit/chainlink-starknet/relayer v0.0.1-beta-test.0.20240325075535-0f7eb05ee595 h1:y6ks0HsSOhPUueOmTcoxDQ50RCS1XINlRDTemZyHjFw= github.com/smartcontractkit/chainlink-starknet/relayer v0.0.1-beta-test.0.20240325075535-0f7eb05ee595/go.mod h1:vV6WfnVIbK5Q1JsIru4YcTG0T1uRpLJm6t2BgCnCSsg= -github.com/smartcontractkit/chainlink-testing-framework v1.28.2 h1:H4/RW9J3EmHi4uJUxREJHkxxBbKRRk/eO3YhuR9a9zI= -github.com/smartcontractkit/chainlink-testing-framework v1.28.2/go.mod h1:jN+HgXbriq6fKRlIqLw9F3I81aYImV6kBJkIfz0mdIA= +github.com/smartcontractkit/chainlink-testing-framework v1.28.3 h1:rZ622PUSE9jJvI2g1SNNcMJedXyMzq9XJ8SbV2j9TvA= +github.com/smartcontractkit/chainlink-testing-framework v1.28.3/go.mod h1:jN+HgXbriq6fKRlIqLw9F3I81aYImV6kBJkIfz0mdIA= github.com/smartcontractkit/chainlink-testing-framework/grafana v0.0.0-20240227164431-18a7065e23ea h1:ZdLmNAfKRjH8AYUvjiiDGUgiWQfq/7iNpxyTkvjx/ko= github.com/smartcontractkit/chainlink-testing-framework/grafana v0.0.0-20240227164431-18a7065e23ea/go.mod h1:gCKC9w6XpNk6jm+XIk2psrkkfxhi421N9NSiFceXW88= github.com/smartcontractkit/chainlink-vrf v0.0.0-20240222010609-cd67d123c772 h1:LQmRsrzzaYYN3wEU1l5tWiccznhvbyGnu2N+wHSXZAo= From 44c9b40e0a77be0609c33d06c3101d8a7163c3e7 Mon Sep 17 00:00:00 2001 From: Dimitris Grigoriou Date: Wed, 17 Apr 2024 19:16:54 +0300 Subject: [PATCH 11/19] Drop unused queryTimeout config from TXM strategy (#12859) * Drop unused queryTimeout config from TXM strategy * Add changeset * Fix changeset * Fix changeset error * Add internal tag --- .changeset/new-forks-grab.md | 5 +++++ common/txmgr/strategies.go | 18 ++++++------------ core/chains/evm/txmgr/evm_tx_store_test.go | 4 ++-- core/chains/evm/txmgr/strategies_test.go | 8 ++------ core/services/blockhashstore/bhs.go | 2 +- core/services/fluxmonitorv2/delegate.go | 2 +- core/services/ocr/delegate.go | 2 +- core/services/relay/evm/evm.go | 2 +- core/services/relay/evm/functions.go | 2 +- 9 files changed, 20 insertions(+), 25 deletions(-) create mode 100644 .changeset/new-forks-grab.md diff --git a/.changeset/new-forks-grab.md b/.changeset/new-forks-grab.md new file mode 100644 index 00000000000..cb078beb29b --- /dev/null +++ b/.changeset/new-forks-grab.md @@ -0,0 +1,5 @@ +--- +"chainlink": removed +--- + +Drop unused queryTimeout config from TXM strategy #internal diff --git a/common/txmgr/strategies.go b/common/txmgr/strategies.go index 3772e6d1d20..6e037658854 100644 --- a/common/txmgr/strategies.go +++ b/common/txmgr/strategies.go @@ -3,7 +3,6 @@ package txmgr import ( "context" "fmt" - "time" "github.com/google/uuid" @@ -14,9 +13,9 @@ var _ txmgrtypes.TxStrategy = SendEveryStrategy{} // NewQueueingTxStrategy creates a new TxStrategy that drops the oldest transactions after the // queue size is exceeded if a queue size is specified, and otherwise does not drop transactions. -func NewQueueingTxStrategy(subject uuid.UUID, queueSize uint32, queryTimeout time.Duration) (strategy txmgrtypes.TxStrategy) { +func NewQueueingTxStrategy(subject uuid.UUID, queueSize uint32) (strategy txmgrtypes.TxStrategy) { if queueSize > 0 { - strategy = NewDropOldestStrategy(subject, queueSize, queryTimeout) + strategy = NewDropOldestStrategy(subject, queueSize) } else { strategy = SendEveryStrategy{} } @@ -41,15 +40,14 @@ var _ txmgrtypes.TxStrategy = DropOldestStrategy{} // DropOldestStrategy will send the newest N transactions, older ones will be // removed from the queue type DropOldestStrategy struct { - subject uuid.UUID - queueSize uint32 - queryTimeout time.Duration + subject uuid.UUID + queueSize uint32 } // NewDropOldestStrategy creates a new TxStrategy that drops the oldest transactions after the // queue size is exceeded. -func NewDropOldestStrategy(subject uuid.UUID, queueSize uint32, queryTimeout time.Duration) DropOldestStrategy { - return DropOldestStrategy{subject, queueSize, queryTimeout} +func NewDropOldestStrategy(subject uuid.UUID, queueSize uint32) DropOldestStrategy { + return DropOldestStrategy{subject, queueSize} } func (s DropOldestStrategy) Subject() uuid.NullUUID { @@ -57,10 +55,6 @@ func (s DropOldestStrategy) Subject() uuid.NullUUID { } func (s DropOldestStrategy) PruneQueue(ctx context.Context, pruneService txmgrtypes.UnstartedTxQueuePruner) (ids []int64, err error) { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, s.queryTimeout) - defer cancel() - // NOTE: We prune one less than the queue size to prevent the queue from exceeding the max queue size. Which could occur if a new transaction is added to the queue right after we prune. ids, err = pruneService.PruneUnstartedTxQueue(ctx, s.queueSize-1, s.subject) if err != nil { diff --git a/core/chains/evm/txmgr/evm_tx_store_test.go b/core/chains/evm/txmgr/evm_tx_store_test.go index 5bb131862ed..6cfc01c20d0 100644 --- a/core/chains/evm/txmgr/evm_tx_store_test.go +++ b/core/chains/evm/txmgr/evm_tx_store_test.go @@ -1841,7 +1841,7 @@ func TestORM_PruneUnstartedTxQueue(t *testing.T) { t.Run("does not prune if queue has not exceeded capacity-1", func(t *testing.T) { subject1 := uuid.New() - strategy1 := txmgrcommon.NewDropOldestStrategy(subject1, uint32(5), cfg.Database().DefaultQueryTimeout()) + strategy1 := txmgrcommon.NewDropOldestStrategy(subject1, uint32(5)) for i := 0; i < 5; i++ { mustCreateUnstartedGeneratedTx(t, txStore, fromAddress, &cltest.FixtureChainID, txRequestWithStrategy(strategy1)) } @@ -1850,7 +1850,7 @@ func TestORM_PruneUnstartedTxQueue(t *testing.T) { t.Run("prunes if queue has exceeded capacity-1", func(t *testing.T) { subject2 := uuid.New() - strategy2 := txmgrcommon.NewDropOldestStrategy(subject2, uint32(3), cfg.Database().DefaultQueryTimeout()) + strategy2 := txmgrcommon.NewDropOldestStrategy(subject2, uint32(3)) for i := 0; i < 5; i++ { mustCreateUnstartedGeneratedTx(t, txStore, fromAddress, &cltest.FixtureChainID, txRequestWithStrategy(strategy2)) } diff --git a/core/chains/evm/txmgr/strategies_test.go b/core/chains/evm/txmgr/strategies_test.go index 19f5f197289..d7f4ceaf450 100644 --- a/core/chains/evm/txmgr/strategies_test.go +++ b/core/chains/evm/txmgr/strategies_test.go @@ -11,7 +11,6 @@ import ( txmgrcommon "github.com/smartcontractkit/chainlink/v2/common/txmgr" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/txmgr/mocks" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" - "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/configtest" ) func Test_SendEveryStrategy(t *testing.T) { @@ -28,10 +27,9 @@ func Test_SendEveryStrategy(t *testing.T) { func Test_DropOldestStrategy_Subject(t *testing.T) { t.Parallel() - cfg := configtest.NewGeneralConfig(t, nil) subject := uuid.New() - s := txmgrcommon.NewDropOldestStrategy(subject, 1, cfg.Database().DefaultQueryTimeout()) + s := txmgrcommon.NewDropOldestStrategy(subject, 1) assert.True(t, s.Subject().Valid) assert.Equal(t, subject, s.Subject().UUID) @@ -39,14 +37,12 @@ func Test_DropOldestStrategy_Subject(t *testing.T) { func Test_DropOldestStrategy_PruneQueue(t *testing.T) { t.Parallel() - cfg := configtest.NewGeneralConfig(t, nil) subject := uuid.New() queueSize := uint32(2) - queryTimeout := cfg.Database().DefaultQueryTimeout() mockTxStore := mocks.NewEvmTxStore(t) t.Run("calls PrineUnstartedTxQueue for the given subject and queueSize, ignoring fromAddress", func(t *testing.T) { - strategy1 := txmgrcommon.NewDropOldestStrategy(subject, queueSize, queryTimeout) + strategy1 := txmgrcommon.NewDropOldestStrategy(subject, queueSize) mockTxStore.On("PruneUnstartedTxQueue", mock.Anything, queueSize-1, subject, mock.Anything, mock.Anything).Once().Return([]int64{1, 2}, nil) ids, err := strategy1.PruneQueue(testutils.Context(t), mockTxStore) require.NoError(t, err) diff --git a/core/services/blockhashstore/bhs.go b/core/services/blockhashstore/bhs.go index d4dd52c5661..4d1fe761c88 100644 --- a/core/services/blockhashstore/bhs.go +++ b/core/services/blockhashstore/bhs.go @@ -104,7 +104,7 @@ func (c *BulletproofBHS) Store(ctx context.Context, blockNum uint64) error { // Set a queue size of 256. At most we store the blockhash of every block, and only the // latest 256 can possibly be stored. - Strategy: txmgrcommon.NewQueueingTxStrategy(c.jobID, 256, c.dbConfig.DefaultQueryTimeout()), + Strategy: txmgrcommon.NewQueueingTxStrategy(c.jobID, 256), }) if err != nil { return errors.Wrap(err, "creating transaction") diff --git a/core/services/fluxmonitorv2/delegate.go b/core/services/fluxmonitorv2/delegate.go index ddb255800b1..72aa04c7201 100644 --- a/core/services/fluxmonitorv2/delegate.go +++ b/core/services/fluxmonitorv2/delegate.go @@ -70,7 +70,7 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) (services [] return nil, err } cfg := chain.Config() - strategy := txmgrcommon.NewQueueingTxStrategy(jb.ExternalJobID, cfg.FluxMonitor().DefaultTransactionQueueDepth(), cfg.Database().DefaultQueryTimeout()) + strategy := txmgrcommon.NewQueueingTxStrategy(jb.ExternalJobID, cfg.FluxMonitor().DefaultTransactionQueueDepth()) var checker txmgr.TransmitCheckerSpec if chain.Config().FluxMonitor().SimulateTransactions() { checker.CheckerType = txmgr.TransmitCheckerTypeSimulate diff --git a/core/services/ocr/delegate.go b/core/services/ocr/delegate.go index 63055543f88..b16ede8089f 100644 --- a/core/services/ocr/delegate.go +++ b/core/services/ocr/delegate.go @@ -197,7 +197,7 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) (services [] } cfg := chain.Config() - strategy := txmgrcommon.NewQueueingTxStrategy(jb.ExternalJobID, cfg.OCR().DefaultTransactionQueueDepth(), cfg.Database().DefaultQueryTimeout()) + strategy := txmgrcommon.NewQueueingTxStrategy(jb.ExternalJobID, cfg.OCR().DefaultTransactionQueueDepth()) var checker txmgr.TransmitCheckerSpec if chain.Config().OCR().SimulateTransactions() { diff --git a/core/services/relay/evm/evm.go b/core/services/relay/evm/evm.go index 1a09e681f8a..4f31110fda1 100644 --- a/core/services/relay/evm/evm.go +++ b/core/services/relay/evm/evm.go @@ -521,7 +521,7 @@ func newOnChainContractTransmitter(ctx context.Context, lggr logger.Logger, rarg subject = *opts.subjectID } scoped := configWatcher.chain.Config() - strategy := txmgrcommon.NewQueueingTxStrategy(subject, scoped.OCR2().DefaultTransactionQueueDepth(), scoped.Database().DefaultQueryTimeout()) + strategy := txmgrcommon.NewQueueingTxStrategy(subject, scoped.OCR2().DefaultTransactionQueueDepth()) var checker txm.TransmitCheckerSpec if configWatcher.chain.Config().OCR2().SimulateTransactions() { diff --git a/core/services/relay/evm/functions.go b/core/services/relay/evm/functions.go index ed7b247f46b..9444ab4164d 100644 --- a/core/services/relay/evm/functions.go +++ b/core/services/relay/evm/functions.go @@ -183,7 +183,7 @@ func newFunctionsContractTransmitter(ctx context.Context, contractVersion uint32 } scoped := configWatcher.chain.Config() - strategy := txmgrcommon.NewQueueingTxStrategy(rargs.ExternalJobID, scoped.OCR2().DefaultTransactionQueueDepth(), scoped.Database().DefaultQueryTimeout()) + strategy := txmgrcommon.NewQueueingTxStrategy(rargs.ExternalJobID, scoped.OCR2().DefaultTransactionQueueDepth()) var checker txm.TransmitCheckerSpec if configWatcher.chain.Config().OCR2().SimulateTransactions() { From 40064f0dfecda6e404993dff056e7a666cca7d26 Mon Sep 17 00:00:00 2001 From: amit-momin <108959691+amit-momin@users.noreply.github.com> Date: Wed, 17 Apr 2024 12:06:44 -0500 Subject: [PATCH 12/19] Update TXM method signature to match TX ID type (#12851) * Updated TXM method signature to match TX ID type * Fixed linting --- .changeset/weak-emus-reply.md | 5 +++++ common/txmgr/mocks/tx_manager.go | 8 +++---- common/txmgr/txmgr.go | 6 +++--- common/txmgr/types/mocks/tx_store.go | 8 +++---- common/txmgr/types/tx_store.go | 2 +- core/chains/evm/txmgr/evm_tx_store.go | 2 +- core/chains/evm/txmgr/evm_tx_store_test.go | 23 +++++++++++++++++++++ core/chains/evm/txmgr/mocks/evm_tx_store.go | 8 +++---- 8 files changed, 45 insertions(+), 17 deletions(-) create mode 100644 .changeset/weak-emus-reply.md diff --git a/.changeset/weak-emus-reply.md b/.changeset/weak-emus-reply.md new file mode 100644 index 00000000000..ef0c1fe4dae --- /dev/null +++ b/.changeset/weak-emus-reply.md @@ -0,0 +1,5 @@ +--- +"chainlink": minor +--- + +#internal Updated FindTxesWithAttemptsAndReceiptsByIdsAndState method signature to accept int64 for tx ID instead of big.Int diff --git a/common/txmgr/mocks/tx_manager.go b/common/txmgr/mocks/tx_manager.go index 37b0822941d..a5f05219217 100644 --- a/common/txmgr/mocks/tx_manager.go +++ b/common/txmgr/mocks/tx_manager.go @@ -184,7 +184,7 @@ func (_m *TxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) FindTx } // FindTxesWithAttemptsAndReceiptsByIdsAndState provides a mock function with given fields: ctx, ids, states, chainID -func (_m *TxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) FindTxesWithAttemptsAndReceiptsByIdsAndState(ctx context.Context, ids []big.Int, states []txmgrtypes.TxState, chainID *big.Int) ([]*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], error) { +func (_m *TxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) FindTxesWithAttemptsAndReceiptsByIdsAndState(ctx context.Context, ids []int64, states []txmgrtypes.TxState, chainID *big.Int) ([]*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], error) { ret := _m.Called(ctx, ids, states, chainID) if len(ret) == 0 { @@ -193,10 +193,10 @@ func (_m *TxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) FindTx var r0 []*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE] var r1 error - if rf, ok := ret.Get(0).(func(context.Context, []big.Int, []txmgrtypes.TxState, *big.Int) ([]*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, []int64, []txmgrtypes.TxState, *big.Int) ([]*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], error)); ok { return rf(ctx, ids, states, chainID) } - if rf, ok := ret.Get(0).(func(context.Context, []big.Int, []txmgrtypes.TxState, *big.Int) []*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]); ok { + if rf, ok := ret.Get(0).(func(context.Context, []int64, []txmgrtypes.TxState, *big.Int) []*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]); ok { r0 = rf(ctx, ids, states, chainID) } else { if ret.Get(0) != nil { @@ -204,7 +204,7 @@ func (_m *TxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) FindTx } } - if rf, ok := ret.Get(1).(func(context.Context, []big.Int, []txmgrtypes.TxState, *big.Int) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, []int64, []txmgrtypes.TxState, *big.Int) error); ok { r1 = rf(ctx, ids, states, chainID) } else { r1 = ret.Error(1) diff --git a/common/txmgr/txmgr.go b/common/txmgr/txmgr.go index b996b76f1a5..39895941ffd 100644 --- a/common/txmgr/txmgr.go +++ b/common/txmgr/txmgr.go @@ -57,7 +57,7 @@ type TxManager[ // Find transactions with a non-null TxMeta field that was provided and a receipt block number greater than or equal to the one provided FindTxesWithMetaFieldByReceiptBlockNum(ctx context.Context, metaField string, blockNum int64, chainID *big.Int) (txes []*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], err error) // Find transactions loaded with transaction attempts and receipts by transaction IDs and states - FindTxesWithAttemptsAndReceiptsByIdsAndState(ctx context.Context, ids []big.Int, states []txmgrtypes.TxState, chainID *big.Int) (txes []*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], err error) + FindTxesWithAttemptsAndReceiptsByIdsAndState(ctx context.Context, ids []int64, states []txmgrtypes.TxState, chainID *big.Int) (txes []*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], err error) FindEarliestUnconfirmedBroadcastTime(ctx context.Context) (nullv4.Time, error) FindEarliestUnconfirmedTxAttemptBlock(ctx context.Context) (nullv4.Int, error) CountTransactionsByState(ctx context.Context, state txmgrtypes.TxState) (count uint32, err error) @@ -587,7 +587,7 @@ func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) FindTxesWi return } -func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) FindTxesWithAttemptsAndReceiptsByIdsAndState(ctx context.Context, ids []big.Int, states []txmgrtypes.TxState, chainID *big.Int) (txes []*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], err error) { +func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) FindTxesWithAttemptsAndReceiptsByIdsAndState(ctx context.Context, ids []int64, states []txmgrtypes.TxState, chainID *big.Int) (txes []*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], err error) { txes, err = b.txStore.FindTxesWithAttemptsAndReceiptsByIdsAndState(ctx, ids, states, chainID) return } @@ -667,7 +667,7 @@ func (n *NullTxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) Fin func (n *NullTxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) FindTxesWithMetaFieldByReceiptBlockNum(ctx context.Context, metaField string, blockNum int64, chainID *big.Int) (txes []*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], err error) { return txes, errors.New(n.ErrMsg) } -func (n *NullTxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) FindTxesWithAttemptsAndReceiptsByIdsAndState(ctx context.Context, ids []big.Int, states []txmgrtypes.TxState, chainID *big.Int) (txes []*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], err error) { +func (n *NullTxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) FindTxesWithAttemptsAndReceiptsByIdsAndState(ctx context.Context, ids []int64, states []txmgrtypes.TxState, chainID *big.Int) (txes []*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], err error) { return txes, errors.New(n.ErrMsg) } diff --git a/common/txmgr/types/mocks/tx_store.go b/common/txmgr/types/mocks/tx_store.go index 814207d3986..8d70fe6b5a9 100644 --- a/common/txmgr/types/mocks/tx_store.go +++ b/common/txmgr/types/mocks/tx_store.go @@ -551,7 +551,7 @@ func (_m *TxStore[ADDR, CHAIN_ID, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) FindTxesPen } // FindTxesWithAttemptsAndReceiptsByIdsAndState provides a mock function with given fields: ctx, ids, states, chainID -func (_m *TxStore[ADDR, CHAIN_ID, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) FindTxesWithAttemptsAndReceiptsByIdsAndState(ctx context.Context, ids []big.Int, states []txmgrtypes.TxState, chainID *big.Int) ([]*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], error) { +func (_m *TxStore[ADDR, CHAIN_ID, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) FindTxesWithAttemptsAndReceiptsByIdsAndState(ctx context.Context, ids []int64, states []txmgrtypes.TxState, chainID *big.Int) ([]*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], error) { ret := _m.Called(ctx, ids, states, chainID) if len(ret) == 0 { @@ -560,10 +560,10 @@ func (_m *TxStore[ADDR, CHAIN_ID, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) FindTxesWit var r0 []*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE] var r1 error - if rf, ok := ret.Get(0).(func(context.Context, []big.Int, []txmgrtypes.TxState, *big.Int) ([]*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, []int64, []txmgrtypes.TxState, *big.Int) ([]*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], error)); ok { return rf(ctx, ids, states, chainID) } - if rf, ok := ret.Get(0).(func(context.Context, []big.Int, []txmgrtypes.TxState, *big.Int) []*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]); ok { + if rf, ok := ret.Get(0).(func(context.Context, []int64, []txmgrtypes.TxState, *big.Int) []*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]); ok { r0 = rf(ctx, ids, states, chainID) } else { if ret.Get(0) != nil { @@ -571,7 +571,7 @@ func (_m *TxStore[ADDR, CHAIN_ID, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) FindTxesWit } } - if rf, ok := ret.Get(1).(func(context.Context, []big.Int, []txmgrtypes.TxState, *big.Int) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, []int64, []txmgrtypes.TxState, *big.Int) error); ok { r1 = rf(ctx, ids, states, chainID) } else { r1 = ret.Error(1) diff --git a/common/txmgr/types/tx_store.go b/common/txmgr/types/tx_store.go index f061f0ea628..43d41cb4d31 100644 --- a/common/txmgr/types/tx_store.go +++ b/common/txmgr/types/tx_store.go @@ -52,7 +52,7 @@ type TxStore[ // Find transactions with a non-null TxMeta field that was provided and a receipt block number greater than or equal to the one provided FindTxesWithMetaFieldByReceiptBlockNum(ctx context.Context, metaField string, blockNum int64, chainID *big.Int) (tx []*Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], err error) // Find transactions loaded with transaction attempts and receipts by transaction IDs and states - FindTxesWithAttemptsAndReceiptsByIdsAndState(ctx context.Context, ids []big.Int, states []TxState, chainID *big.Int) (tx []*Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], err error) + FindTxesWithAttemptsAndReceiptsByIdsAndState(ctx context.Context, ids []int64, states []TxState, chainID *big.Int) (tx []*Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], err error) } // TransactionStore contains the persistence layer methods needed to manage Txs and TxAttempts diff --git a/core/chains/evm/txmgr/evm_tx_store.go b/core/chains/evm/txmgr/evm_tx_store.go index 55f650e934b..7a05f32c9e4 100644 --- a/core/chains/evm/txmgr/evm_tx_store.go +++ b/core/chains/evm/txmgr/evm_tx_store.go @@ -1939,7 +1939,7 @@ func (o *evmTxStore) FindTxesWithMetaFieldByReceiptBlockNum(ctx context.Context, } // Find transactions loaded with transaction attempts and receipts by transaction IDs and states -func (o *evmTxStore) FindTxesWithAttemptsAndReceiptsByIdsAndState(ctx context.Context, ids []big.Int, states []txmgrtypes.TxState, chainID *big.Int) (txes []*Tx, err error) { +func (o *evmTxStore) FindTxesWithAttemptsAndReceiptsByIdsAndState(ctx context.Context, ids []int64, states []txmgrtypes.TxState, chainID *big.Int) (txes []*Tx, err error) { var cancel context.CancelFunc ctx, cancel = o.mergeContexts(ctx) defer cancel() diff --git a/core/chains/evm/txmgr/evm_tx_store_test.go b/core/chains/evm/txmgr/evm_tx_store_test.go index 6cfc01c20d0..5d3fdcfafd8 100644 --- a/core/chains/evm/txmgr/evm_tx_store_test.go +++ b/core/chains/evm/txmgr/evm_tx_store_test.go @@ -1858,6 +1858,29 @@ func TestORM_PruneUnstartedTxQueue(t *testing.T) { }) } +func TestORM_FindTxesWithAttemptsAndReceiptsByIdsAndState(t *testing.T) { + t.Parallel() + + db := pgtest.NewSqlxDB(t) + cfg := configtest.NewGeneralConfig(t, nil) + txStore := cltest.NewTestTxStore(t, db) + ethKeyStore := cltest.NewKeyStore(t, db, cfg.Database()).Eth() + ctx := testutils.Context(t) + + _, from := cltest.MustInsertRandomKey(t, ethKeyStore) + + tx := cltest.MustInsertConfirmedEthTxWithLegacyAttempt(t, txStore, 0, 1, from) + r := newEthReceipt(4, utils.NewHash(), tx.TxAttempts[0].Hash, 0x1) + _, err := txStore.InsertReceipt(ctx, &r.Receipt) + require.NoError(t, err) + + txes, err := txStore.FindTxesWithAttemptsAndReceiptsByIdsAndState(ctx, []int64{tx.ID}, []txmgrtypes.TxState{txmgrcommon.TxConfirmed}, testutils.FixtureChainID) + require.NoError(t, err) + require.Len(t, txes, 1) + require.Len(t, txes[0].TxAttempts, 1) + require.Len(t, txes[0].TxAttempts[0].Receipts, 1) +} + func AssertCountPerSubject(t *testing.T, txStore txmgr.TestEvmTxStore, expected int64, subject uuid.UUID) { t.Helper() count, err := txStore.CountTxesByStateAndSubject(testutils.Context(t), "unstarted", subject) diff --git a/core/chains/evm/txmgr/mocks/evm_tx_store.go b/core/chains/evm/txmgr/mocks/evm_tx_store.go index 61c948c1ff4..b6806f34d76 100644 --- a/core/chains/evm/txmgr/mocks/evm_tx_store.go +++ b/core/chains/evm/txmgr/mocks/evm_tx_store.go @@ -672,7 +672,7 @@ func (_m *EvmTxStore) FindTxesPendingCallback(ctx context.Context, blockNum int6 } // FindTxesWithAttemptsAndReceiptsByIdsAndState provides a mock function with given fields: ctx, ids, states, chainID -func (_m *EvmTxStore) FindTxesWithAttemptsAndReceiptsByIdsAndState(ctx context.Context, ids []big.Int, states []types.TxState, chainID *big.Int) ([]*types.Tx[*big.Int, common.Address, common.Hash, common.Hash, evmtypes.Nonce, gas.EvmFee], error) { +func (_m *EvmTxStore) FindTxesWithAttemptsAndReceiptsByIdsAndState(ctx context.Context, ids []int64, states []types.TxState, chainID *big.Int) ([]*types.Tx[*big.Int, common.Address, common.Hash, common.Hash, evmtypes.Nonce, gas.EvmFee], error) { ret := _m.Called(ctx, ids, states, chainID) if len(ret) == 0 { @@ -681,10 +681,10 @@ func (_m *EvmTxStore) FindTxesWithAttemptsAndReceiptsByIdsAndState(ctx context.C var r0 []*types.Tx[*big.Int, common.Address, common.Hash, common.Hash, evmtypes.Nonce, gas.EvmFee] var r1 error - if rf, ok := ret.Get(0).(func(context.Context, []big.Int, []types.TxState, *big.Int) ([]*types.Tx[*big.Int, common.Address, common.Hash, common.Hash, evmtypes.Nonce, gas.EvmFee], error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, []int64, []types.TxState, *big.Int) ([]*types.Tx[*big.Int, common.Address, common.Hash, common.Hash, evmtypes.Nonce, gas.EvmFee], error)); ok { return rf(ctx, ids, states, chainID) } - if rf, ok := ret.Get(0).(func(context.Context, []big.Int, []types.TxState, *big.Int) []*types.Tx[*big.Int, common.Address, common.Hash, common.Hash, evmtypes.Nonce, gas.EvmFee]); ok { + if rf, ok := ret.Get(0).(func(context.Context, []int64, []types.TxState, *big.Int) []*types.Tx[*big.Int, common.Address, common.Hash, common.Hash, evmtypes.Nonce, gas.EvmFee]); ok { r0 = rf(ctx, ids, states, chainID) } else { if ret.Get(0) != nil { @@ -692,7 +692,7 @@ func (_m *EvmTxStore) FindTxesWithAttemptsAndReceiptsByIdsAndState(ctx context.C } } - if rf, ok := ret.Get(1).(func(context.Context, []big.Int, []types.TxState, *big.Int) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, []int64, []types.TxState, *big.Int) error); ok { r1 = rf(ctx, ids, states, chainID) } else { r1 = ret.Error(1) From 2f60dbee54b3e17dd6caad7dbc9bd14ff71c843b Mon Sep 17 00:00:00 2001 From: Tate Date: Wed, 17 Apr 2024 11:36:18 -0600 Subject: [PATCH 13/19] Fix e2e test log artifact name collision on upload (#12861) --- .github/workflows/integration-tests.yml | 25 ++++++++----------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 1386bfef8f3..98d67a8b2d3 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -379,6 +379,7 @@ jobs: cl_repo: ${{ env.CHAINLINK_IMAGE }} cl_image_tag: ${{ inputs.evm-ref || github.sha }} aws_registries: ${{ secrets.QA_AWS_ACCOUNT_NUMBER }} + artifacts_name: ${{ matrix.product.name }}-test-logs artifacts_location: | ./integration-tests/smoke/logs/ /tmp/gotest.log @@ -470,6 +471,7 @@ jobs: cl_repo: ${{ env.CHAINLINK_IMAGE }} cl_image_tag: ${{ inputs.evm-ref || github.sha }} aws_registries: ${{ secrets.QA_AWS_ACCOUNT_NUMBER }} + artifacts_name: ${{ matrix.product.name }}-test-logs artifacts_location: | ./integration-tests/smoke/logs/ /tmp/gotest.log @@ -854,6 +856,7 @@ jobs: test_download_vendor_packages_command: cd ./integration-tests && go mod download cl_repo: ${{ env.CHAINLINK_IMAGE }} cl_image_tag: ${{ steps.get_latest_version.outputs.latest_version }} + artifacts_name: node-migration-test-logs artifacts_location: | ./integration-tests/migration/logs /tmp/gotest.log @@ -865,14 +868,6 @@ jobs: QA_AWS_REGION: ${{ secrets.QA_AWS_REGION }} QA_AWS_ROLE_TO_ASSUME: ${{ secrets.QA_AWS_ROLE_TO_ASSUME }} QA_KUBECONFIG: ${{ secrets.QA_KUBECONFIG }} - - name: Upload test log - uses: actions/upload-artifact@5d5d22a31266ced268874388b861e4b58bb5c2f3 # v4.3.1 - if: failure() - with: - name: test-log-${{ matrix.product.name }} - path: /tmp/gotest.log - retention-days: 7 - continue-on-error: true - name: Collect Metrics if: always() id: collect-gha-metrics @@ -1160,21 +1155,17 @@ jobs: test_command_to_run: export ENV_JOB_IMAGE=${{ secrets.QA_AWS_ACCOUNT_NUMBER }}.dkr.ecr.${{ secrets.QA_AWS_REGION }}.amazonaws.com/chainlink-solana-tests:${{ needs.get_solana_sha.outputs.sha }} && make test_smoke cl_repo: ${{ env.CHAINLINK_IMAGE }} cl_image_tag: ${{ inputs.evm-ref || github.sha }} - artifacts_location: /home/runner/work/chainlink-solana/chainlink-solana/integration-tests/logs publish_check_name: Solana Smoke Test Results go_mod_path: ./integration-tests/go.mod cache_key_id: core-solana-e2e-${{ env.MOD_CACHE_VERSION }} token: ${{ secrets.GITHUB_TOKEN }} aws_registries: ${{ secrets.QA_AWS_ACCOUNT_NUMBER }} + artifacts_name: solana-test-logs + artifacts_location: | + ./integration-tests/smoke/logs + /tmp/gotest.log QA_AWS_REGION: ${{ secrets.QA_AWS_REGION }} QA_AWS_ROLE_TO_ASSUME: ${{ secrets.QA_AWS_ROLE_TO_ASSUME }} QA_KUBECONFIG: "" run_setup: false - - name: Upload test log - uses: actions/upload-artifact@5d5d22a31266ced268874388b861e4b58bb5c2f3 # v4.3.1 - if: failure() - with: - name: test-log-solana - path: /tmp/gotest.log - retention-days: 7 - continue-on-error: true + From b468bc961941cbbfa5de980a970f6c6a35365d7f Mon Sep 17 00:00:00 2001 From: Austin Born Date: Wed, 17 Apr 2024 12:54:27 -0700 Subject: [PATCH 14/19] Update CODEOWNERS for OperatorForwarder contracts (#12868) --- CODEOWNERS | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CODEOWNERS b/CODEOWNERS index 65f0b58753f..8741cb7a685 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -75,7 +75,7 @@ core/scripts/gateway @smartcontractkit/functions /contracts/src/v0.8/l2ep @chris-de-leon-cll /contracts/src/v0.8/llo-feeds @smartcontractkit/mercury-team # TODO: mocks folder, folder should be removed and files moved to the correct folders -/contracts/src/v0.8/operatorforwarder @RensR +/contracts/src/v0.8/operatorforwarder @austinborn /contracts/src/v0.8/shared @RensR # TODO: tests folder, folder should be removed and files moved to the correct folders # TODO: transmission folder, owner should be found From ccb8cd85fef8e3bbe3fb5580277a7bd7f477e6bb Mon Sep 17 00:00:00 2001 From: Dylan Tinianov Date: Wed, 17 Apr 2024 16:09:25 -0400 Subject: [PATCH 15/19] Fix and re-enable abandoned tx tracker (#12533) * Re-enable tracker * generate * update tracker * fix race conditions * Update tracker.go * Update tracker.go * Ensure thread safety * Update tracker.go * concurrently track txes * Optimizations * update logging * Update tracker.go * Update CHANGELOG.md * Update common/txmgr/tracker.go Co-authored-by: Jim W * GetAbandonedTransactions * lint * enabled address check * Update tracker_test.go * Update tracker.go * Ignore confirmed txes * Don't resend abandoned txes * Remove comment * Update tracker_test.go * Remove unused block height * changeset * Update tracker_test.go --------- Co-authored-by: Jim W Co-authored-by: Prashant Yadav <34992934+prashantkumar1982@users.noreply.github.com> --- .changeset/gold-bottles-tell.md | 5 + common/txmgr/resender.go | 3 - common/txmgr/tracker.go | 257 ++++++++++---------- common/txmgr/txmgr.go | 43 ++-- common/txmgr/types/mocks/tx_store.go | 48 ++-- common/txmgr/types/tx_store.go | 2 +- core/chains/evm/txmgr/evm_tx_store.go | 34 +-- core/chains/evm/txmgr/evm_tx_store_test.go | 39 ++- core/chains/evm/txmgr/mocks/evm_tx_store.go | 48 ++-- core/chains/evm/txmgr/tracker_test.go | 48 ++-- 10 files changed, 276 insertions(+), 251 deletions(-) create mode 100644 .changeset/gold-bottles-tell.md diff --git a/.changeset/gold-bottles-tell.md b/.changeset/gold-bottles-tell.md new file mode 100644 index 00000000000..5289f368a55 --- /dev/null +++ b/.changeset/gold-bottles-tell.md @@ -0,0 +1,5 @@ +--- +"chainlink": minor +--- + +#added : Re-enable abandoned transaction tracker diff --git a/common/txmgr/resender.go b/common/txmgr/resender.go index 8c2dd6b827e..b752ec63f13 100644 --- a/common/txmgr/resender.go +++ b/common/txmgr/resender.go @@ -140,9 +140,6 @@ func (er *Resender[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) resendUnco return fmt.Errorf("Resender failed getting enabled keys for chain %s: %w", er.chainID.String(), err) } - // Tracker currently disabled for BCI-2638; refactor required - // resendAddresses = append(resendAddresses, er.tracker.GetAbandonedAddresses()...) - ageThreshold := er.txConfig.ResendAfterThreshold() maxInFlightTransactions := er.txConfig.MaxInFlight() olderThan := time.Now().Add(-ageThreshold) diff --git a/common/txmgr/tracker.go b/common/txmgr/tracker.go index c63d9c264fc..a7236472710 100644 --- a/common/txmgr/tracker.go +++ b/common/txmgr/tracker.go @@ -8,6 +8,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/utils/mailbox" feetypes "github.com/smartcontractkit/chainlink/v2/common/fee/types" @@ -22,17 +23,10 @@ const ( // handleTxesTimeout represents a sanity limit on how long handleTxesByState // should take to complete handleTxesTimeout = 10 * time.Minute + // batchSize is the number of txes to fetch from the txStore at once + batchSize = 1000 ) -// AbandonedTx is a transaction who's 'FromAddress' was removed from the KeyStore(by the Node Operator). -// Thus, any new attempts for this Tx can't be signed/created. This means no fee bumping can be done. -// However, the Tx may still have live attempts in the chain's mempool, and could get confirmed on the -// chain as-is. Thus, the TXM should not directly discard this Tx. -type AbandonedTx[ADDR types.Hashable] struct { - id int64 - fromAddress ADDR -} - // Tracker tracks all transactions which have abandoned fromAddresses. // The fromAddresses can be deleted by Node Operators from the KeyStore. In such cases, // existing in-flight transactions for these fromAddresses are considered abandoned too. @@ -48,19 +42,22 @@ type Tracker[ FEE feetypes.Fee, ] struct { services.StateMachine - txStore txmgrtypes.TxStore[ADDR, CHAIN_ID, TX_HASH, BLOCK_HASH, R, SEQ, FEE] - keyStore txmgrtypes.KeyStore[ADDR, CHAIN_ID, SEQ] - chainID CHAIN_ID - lggr logger.Logger - enabledAddrs map[ADDR]bool - txCache map[int64]AbandonedTx[ADDR] - ttl time.Duration + txStore txmgrtypes.TxStore[ADDR, CHAIN_ID, TX_HASH, BLOCK_HASH, R, SEQ, FEE] + keyStore txmgrtypes.KeyStore[ADDR, CHAIN_ID, SEQ] + chainID CHAIN_ID + lggr logger.Logger + lock sync.Mutex - mb *mailbox.Mailbox[int64] - wg sync.WaitGroup - isStarted bool - ctx context.Context - ctxCancel context.CancelFunc + enabledAddrs map[ADDR]bool + txCache map[int64]ADDR // cache tx fromAddress by txID + + ttl time.Duration + mb *mailbox.Mailbox[int64] + + initSync sync.Mutex + wg sync.WaitGroup + chStop services.StopChan + isStarted bool } func NewTracker[ @@ -83,7 +80,7 @@ func NewTracker[ chainID: chainID, lggr: logger.Named(lggr, "TxMgrTracker"), enabledAddrs: map[ADDR]bool{}, - txCache: map[int64]AbandonedTx[ADDR]{}, + txCache: map[int64]ADDR{}, ttl: defaultTTL, mb: mailbox.NewSingle[int64](), lock: sync.Mutex{}, @@ -99,75 +96,84 @@ func (tr *Tracker[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) Start(ctx c } func (tr *Tracker[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) startInternal(ctx context.Context) (err error) { - tr.lock.Lock() - defer tr.lock.Unlock() - - tr.ctx, tr.ctxCancel = context.WithCancel(context.Background()) + tr.initSync.Lock() + defer tr.initSync.Unlock() if err := tr.setEnabledAddresses(ctx); err != nil { return fmt.Errorf("failed to set enabled addresses: %w", err) } - tr.lggr.Info("Enabled addresses set") + tr.lggr.Infof("enabled addresses set for chainID %v", tr.chainID) - if err := tr.trackAbandonedTxes(ctx); err != nil { - return fmt.Errorf("failed to track abandoned txes: %w", err) - } - - tr.isStarted = true - if len(tr.txCache) == 0 { - tr.lggr.Info("no abandoned txes found, skipping runLoop") - return nil - } - - tr.lggr.Infof("%d abandoned txes found, starting runLoop", len(tr.txCache)) + tr.chStop = make(chan struct{}) tr.wg.Add(1) - go tr.runLoop() + go tr.runLoop(tr.chStop.NewCtx()) + tr.isStarted = true return nil } func (tr *Tracker[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) Close() error { - tr.lock.Lock() - defer tr.lock.Unlock() return tr.StopOnce("Tracker", func() error { return tr.closeInternal() }) } func (tr *Tracker[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) closeInternal() error { + tr.initSync.Lock() + defer tr.initSync.Unlock() + tr.lggr.Info("stopping tracker") if !tr.isStarted { - return fmt.Errorf("tracker not started") + return fmt.Errorf("tracker is not started: %w", services.ErrAlreadyStopped) } - tr.ctxCancel() + + close(tr.chStop) tr.wg.Wait() tr.isStarted = false return nil } -func (tr *Tracker[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) runLoop() { +func (tr *Tracker[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) runLoop(ctx context.Context, cancel context.CancelFunc) { defer tr.wg.Done() + defer cancel() + + if err := tr.trackAbandonedTxes(ctx); err != nil { + tr.lggr.Errorf("failed to track abandoned txes: %v", err) + return + } + if err := tr.handleTxesByState(ctx); err != nil { + tr.lggr.Errorf("failed to handle txes by state: %v", err) + return + } + if tr.AbandonedTxCount() == 0 { + tr.lggr.Info("no abandoned txes found, skipping runLoop") + return + } + tr.lggr.Infof("%d abandoned txes found, starting runLoop", tr.AbandonedTxCount()) + ttlExceeded := time.NewTicker(tr.ttl) defer ttlExceeded.Stop() for { select { case <-tr.mb.Notify(): for { - if tr.ctx.Err() != nil { - return - } - blockHeight, exists := tr.mb.Retrieve() - if !exists { + blockHeight := tr.mb.RetrieveLatestAndClear() + if blockHeight == 0 { break } - if err := tr.HandleTxesByState(tr.ctx, blockHeight); err != nil { - tr.lggr.Errorw(fmt.Errorf("failed to handle txes by state: %w", err).Error()) + if err := tr.handleTxesByState(ctx); err != nil { + tr.lggr.Errorf("failed to handle txes by state: %v", err) + return + } + if tr.AbandonedTxCount() == 0 { + tr.lggr.Info("all abandoned txes handled, stopping runLoop") + return } } case <-ttlExceeded.C: tr.lggr.Info("ttl exceeded") - tr.MarkAllTxesFatal(tr.ctx) + tr.markAllTxesFatal(ctx) return - case <-tr.ctx.Done(): + case <-ctx.Done(): return } } @@ -177,24 +183,31 @@ func (tr *Tracker[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) GetAbandone tr.lock.Lock() defer tr.lock.Unlock() - if !tr.isStarted { - return []ADDR{} - } - abandonedAddrs := make([]ADDR, len(tr.txCache)) - for _, atx := range tr.txCache { - abandonedAddrs = append(abandonedAddrs, atx.fromAddress) + for _, fromAddress := range tr.txCache { + abandonedAddrs = append(abandonedAddrs, fromAddress) } return abandonedAddrs } -func (tr *Tracker[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) IsStarted() bool { +// AbandonedTxCount returns the number of abandoned txes currently being tracked +func (tr *Tracker[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) AbandonedTxCount() int { tr.lock.Lock() defer tr.lock.Unlock() + return len(tr.txCache) +} + +func (tr *Tracker[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) IsStarted() bool { + tr.initSync.Lock() + defer tr.initSync.Unlock() return tr.isStarted } +// setEnabledAddresses is called on startup to set the enabled addresses for the chain func (tr *Tracker[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) setEnabledAddresses(ctx context.Context) error { + tr.lock.Lock() + defer tr.lock.Unlock() + enabledAddrs, err := tr.keyStore.EnabledAddressesForChain(ctx, tr.chainID) if err != nil { return fmt.Errorf("failed to get enabled addresses for chain: %w", err) @@ -210,54 +223,58 @@ func (tr *Tracker[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) setEnabledA return nil } -// trackAbandonedTxes called once to find and insert all abandoned txes into the tracker. +// trackAbandonedTxes called on startup to find and insert all abandoned txes into the tracker. func (tr *Tracker[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) trackAbandonedTxes(ctx context.Context) (err error) { - if tr.isStarted { - return fmt.Errorf("tracker already started") - } - - tr.lggr.Info("Retrieving non fatal transactions from txStore") - nonFatalTxes, err := tr.txStore.GetNonFatalTransactions(ctx, tr.chainID) - if err != nil { - return fmt.Errorf("failed to get non fatal txes from txStore: %w", err) - } - - // insert abandoned txes - for _, tx := range nonFatalTxes { - if !tr.enabledAddrs[tx.FromAddress] { - tr.insertTx(tx) + return sqlutil.Batch(func(offset, limit uint) (count uint, err error) { + var enabledAddrs []ADDR + for addr := range tr.enabledAddrs { + enabledAddrs = append(enabledAddrs, addr) } - } - if err := tr.handleTxesByState(ctx, 0); err != nil { - return fmt.Errorf("failed to handle txes by state: %w", err) - } - - return nil + nonFatalTxes, err := tr.txStore.GetAbandonedTransactionsByBatch(ctx, tr.chainID, enabledAddrs, offset, limit) + if err != nil { + return 0, fmt.Errorf("failed to get non fatal txes from txStore: %w", err) + } + // insert abandoned txes + tr.lock.Lock() + for _, tx := range nonFatalTxes { + if !tr.enabledAddrs[tx.FromAddress] { + tr.txCache[tx.ID] = tx.FromAddress + tr.lggr.Debugf("inserted tx %v", tx.ID) + } + } + tr.lock.Unlock() + return uint(len(nonFatalTxes)), nil + }, batchSize) } -func (tr *Tracker[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) HandleTxesByState(ctx context.Context, blockHeight int64) error { +// handleTxesByState handles all txes in the txCache by their state +// It's called on every new blockHeight and also on startup to handle all txes in the txCache +func (tr *Tracker[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) handleTxesByState(ctx context.Context) error { tr.lock.Lock() defer tr.lock.Unlock() - tr.ctx, tr.ctxCancel = context.WithTimeout(ctx, handleTxesTimeout) - defer tr.ctxCancel() - return tr.handleTxesByState(ctx, blockHeight) -} + ctx, cancel := context.WithTimeout(ctx, handleTxesTimeout) + defer cancel() -func (tr *Tracker[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) handleTxesByState(ctx context.Context, blockHeight int64) error { - tr.lggr.Info("Handling transactions by state") + for id := range tr.txCache { + if ctx.Err() != nil { + return ctx.Err() + } - for id, atx := range tr.txCache { - tx, err := tr.txStore.GetTxByID(ctx, atx.id) + tx, err := tr.txStore.GetTxByID(ctx, id) if err != nil { - return fmt.Errorf("failed to get tx by ID: %w", err) + tr.lggr.Errorf("failed to get tx by ID: %v", err) + continue + } + if tx == nil { + tr.lggr.Warnf("tx with ID %v no longer exists, removing from tracker", id) + delete(tr.txCache, id) + continue } switch tx.State { case TxConfirmed: - if err := tr.handleConfirmedTx(tx, blockHeight); err != nil { - return fmt.Errorf("failed to handle confirmed txes: %w", err) - } + // TODO: Handle finalized state https://smartcontract-it.atlassian.net/browse/BCI-2920 case TxConfirmedMissingReceipt, TxUnconfirmed: // Keep tracking tx case TxInProgress, TxUnstarted: @@ -266,50 +283,20 @@ func (tr *Tracker[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) handleTxesB // is deleted, we can't sign it. errMsg := "The FromAddress for this Tx was deleted before this Tx could be broadcast to the chain." if err := tr.markTxFatal(ctx, tx, errMsg); err != nil { - return fmt.Errorf("failed to mark tx as fatal: %w", err) + tr.lggr.Errorf("failed to mark tx as fatal: %v", err) + continue } delete(tr.txCache, id) case TxFatalError: delete(tr.txCache, id) default: - tr.lggr.Errorw(fmt.Sprintf("unhandled transaction state: %v", tx.State)) + tr.lggr.Errorf("unhandled transaction state: %v", tx.State) } } return nil } -// handleConfirmedTx removes a transaction from the tracker if it's been finalized on chain -func (tr *Tracker[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) handleConfirmedTx( - tx *txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], - blockHeight int64, -) error { - finalized, err := tr.txStore.IsTxFinalized(tr.ctx, blockHeight, tx.ID, tr.chainID) - if err != nil { - return fmt.Errorf("failed to check if tx is finalized: %w", err) - } - - if finalized { - delete(tr.txCache, tx.ID) - } - - return nil -} - -// insertTx inserts a transaction into the tracker as an AbandonedTx -func (tr *Tracker[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) insertTx( - tx *txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) { - if _, contains := tr.txCache[tx.ID]; contains { - return - } - - tr.txCache[tx.ID] = AbandonedTx[ADDR]{ - id: tx.ID, - fromAddress: tx.FromAddress, - } - tr.lggr.Debugw(fmt.Sprintf("inserted tx %v", tx.ID)) -} - func (tr *Tracker[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) markTxFatal(ctx context.Context, tx *txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], errMsg string) error { @@ -323,22 +310,26 @@ func (tr *Tracker[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) markTxFatal return nil } -func (tr *Tracker[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) MarkAllTxesFatal(ctx context.Context) { +// markAllTxesFatal tries to mark all txes in the txCache as fatal and removes them from the cache +func (tr *Tracker[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) markAllTxesFatal(ctx context.Context) { tr.lock.Lock() defer tr.lock.Unlock() + errMsg := fmt.Sprintf( - "fromAddress for this Tx was deleted, and existing attempts onchain didn't finalize within %d hours, thus this Tx was abandoned.", + "tx abandoned: fromAddress for this tx was deleted and existing attempts didn't finalize onchain within %d hours", int(tr.ttl.Hours())) - for _, atx := range tr.txCache { - tx, err := tr.txStore.GetTxByID(ctx, atx.id) + for id := range tr.txCache { + tx, err := tr.txStore.GetTxByID(ctx, id) if err != nil { - tr.lggr.Errorw(fmt.Errorf("failed to get tx by ID: %w", err).Error()) + tr.lggr.Errorf("failed to get tx by ID: %v", err) + delete(tr.txCache, id) continue } if err := tr.markTxFatal(ctx, tx, errMsg); err != nil { - tr.lggr.Errorw(fmt.Errorf("failed to mark tx as abandoned: %w", err).Error()) + tr.lggr.Errorf("failed to mark tx as abandoned: %v", err) } + delete(tr.txCache, id) } } diff --git a/common/txmgr/txmgr.go b/common/txmgr/txmgr.go index 39895941ffd..4d4eabe5c40 100644 --- a/common/txmgr/txmgr.go +++ b/common/txmgr/txmgr.go @@ -190,12 +190,9 @@ func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) Start(ctx return fmt.Errorf("Txm: Estimator failed to start: %w", err) } - /* Tracker currently disabled for BCI-2638; refactor required - b.logger.Info("Txm starting tracker") if err := ms.Start(ctx, b.tracker); err != nil { return fmt.Errorf("Txm: Tracker failed to start: %w", err) } - */ b.logger.Info("Txm starting runLoop") b.wg.Add(1) @@ -275,12 +272,6 @@ func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) Close() (m merr = errors.Join(merr, fmt.Errorf("Txm: failed to close TxAttemptBuilder: %w", err)) } - /* Tracker currently disabled for BCI-2638; refactor required - if err := b.tracker.Close(); err != nil { - merr = errors.Join(merr, fmt.Errorf("Txm: failed to close Tracker: %w", err)) - } - */ - return nil }) } @@ -329,6 +320,9 @@ func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) runLoop() if err := b.broadcaster.closeInternal(); err != nil { b.logger.Panicw(fmt.Sprintf("Failed to Close Broadcaster: %v", err), "err", err) } + if err := b.tracker.closeInternal(); err != nil { + b.logger.Panicw(fmt.Sprintf("Failed to Close Tracker: %v", err), "err", err) + } if err := b.confirmer.closeInternal(); err != nil { b.logger.Panicw(fmt.Sprintf("Failed to Close Confirmer: %v", err), "err", err) } @@ -337,16 +331,17 @@ func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) runLoop() close(r.done) } var wg sync.WaitGroup - // two goroutines to handle independent backoff retries starting: + // three goroutines to handle independent backoff retries starting: // - Broadcaster // - Confirmer + // - Tracker // If chStop is closed, we mark stopped=true so that the main runloop // can check and exit early if necessary // // execReset will not return until either: - // 1. Both Broadcaster and Confirmer started successfully + // 1. Broadcaster, Confirmer, and Tracker all started successfully // 2. chStop was closed (txmgr exit) - wg.Add(2) + wg.Add(3) go func() { defer wg.Done() // Retry indefinitely on failure @@ -366,6 +361,25 @@ func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) runLoop() } } }() + go func() { + defer wg.Done() + // Retry indefinitely on failure + backoff := iutils.NewRedialBackoff() + for { + select { + case <-time.After(backoff.Duration()): + if err := b.tracker.startInternal(ctx); err != nil { + b.logger.Criticalw("Failed to start Tracker", "err", err) + b.SvcErrBuffer.Append(err) + continue + } + return + case <-b.chStop: + stopOnce.Do(func() { stopped = true }) + return + } + } + }() go func() { defer wg.Done() // Retry indefinitely on failure @@ -395,8 +409,7 @@ func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) runLoop() b.broadcaster.Trigger(address) case head := <-b.chHeads: b.confirmer.mb.Deliver(head) - // Tracker currently disabled for BCI-2638; refactor required - // b.tracker.mb.Deliver(head.BlockNumber()) + b.tracker.mb.Deliver(head.BlockNumber()) case reset := <-b.reset: // This check prevents the weird edge-case where you can select // into this block after chStop has already been closed and the @@ -424,12 +437,10 @@ func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) runLoop() if err != nil && (!errors.Is(err, services.ErrAlreadyStopped) || !errors.Is(err, services.ErrCannotStopUnstarted)) { b.logger.Errorw(fmt.Sprintf("Failed to Close Confirmer: %v", err), "err", err) } - /* Tracker currently disabled for BCI-2638; refactor required err = b.tracker.Close() if err != nil && (!errors.Is(err, services.ErrAlreadyStopped) || !errors.Is(err, services.ErrCannotStopUnstarted)) { b.logger.Errorw(fmt.Sprintf("Failed to Close Tracker: %v", err), "err", err) } - */ return case <-keysChanged: // This check prevents the weird edge-case where you can select diff --git a/common/txmgr/types/mocks/tx_store.go b/common/txmgr/types/mocks/tx_store.go index 8d70fe6b5a9..64193afff5b 100644 --- a/common/txmgr/types/mocks/tx_store.go +++ b/common/txmgr/types/mocks/tx_store.go @@ -700,29 +700,29 @@ func (_m *TxStore[ADDR, CHAIN_ID, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) FindTxsRequ return r0, r1 } -// GetInProgressTxAttempts provides a mock function with given fields: ctx, address, chainID -func (_m *TxStore[ADDR, CHAIN_ID, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) GetInProgressTxAttempts(ctx context.Context, address ADDR, chainID CHAIN_ID) ([]txmgrtypes.TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], error) { - ret := _m.Called(ctx, address, chainID) +// GetAbandonedTransactionsByBatch provides a mock function with given fields: ctx, chainID, enabledAddrs, offset, limit +func (_m *TxStore[ADDR, CHAIN_ID, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) GetAbandonedTransactionsByBatch(ctx context.Context, chainID CHAIN_ID, enabledAddrs []ADDR, offset uint, limit uint) ([]*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], error) { + ret := _m.Called(ctx, chainID, enabledAddrs, offset, limit) if len(ret) == 0 { - panic("no return value specified for GetInProgressTxAttempts") + panic("no return value specified for GetAbandonedTransactionsByBatch") } - var r0 []txmgrtypes.TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE] + var r0 []*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE] var r1 error - if rf, ok := ret.Get(0).(func(context.Context, ADDR, CHAIN_ID) ([]txmgrtypes.TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], error)); ok { - return rf(ctx, address, chainID) + if rf, ok := ret.Get(0).(func(context.Context, CHAIN_ID, []ADDR, uint, uint) ([]*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], error)); ok { + return rf(ctx, chainID, enabledAddrs, offset, limit) } - if rf, ok := ret.Get(0).(func(context.Context, ADDR, CHAIN_ID) []txmgrtypes.TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]); ok { - r0 = rf(ctx, address, chainID) + if rf, ok := ret.Get(0).(func(context.Context, CHAIN_ID, []ADDR, uint, uint) []*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]); ok { + r0 = rf(ctx, chainID, enabledAddrs, offset, limit) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).([]txmgrtypes.TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) + r0 = ret.Get(0).([]*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) } } - if rf, ok := ret.Get(1).(func(context.Context, ADDR, CHAIN_ID) error); ok { - r1 = rf(ctx, address, chainID) + if rf, ok := ret.Get(1).(func(context.Context, CHAIN_ID, []ADDR, uint, uint) error); ok { + r1 = rf(ctx, chainID, enabledAddrs, offset, limit) } else { r1 = ret.Error(1) } @@ -730,29 +730,29 @@ func (_m *TxStore[ADDR, CHAIN_ID, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) GetInProgre return r0, r1 } -// GetNonFatalTransactions provides a mock function with given fields: ctx, chainID -func (_m *TxStore[ADDR, CHAIN_ID, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) GetNonFatalTransactions(ctx context.Context, chainID CHAIN_ID) ([]*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], error) { - ret := _m.Called(ctx, chainID) +// GetInProgressTxAttempts provides a mock function with given fields: ctx, address, chainID +func (_m *TxStore[ADDR, CHAIN_ID, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) GetInProgressTxAttempts(ctx context.Context, address ADDR, chainID CHAIN_ID) ([]txmgrtypes.TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], error) { + ret := _m.Called(ctx, address, chainID) if len(ret) == 0 { - panic("no return value specified for GetNonFatalTransactions") + panic("no return value specified for GetInProgressTxAttempts") } - var r0 []*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE] + var r0 []txmgrtypes.TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE] var r1 error - if rf, ok := ret.Get(0).(func(context.Context, CHAIN_ID) ([]*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], error)); ok { - return rf(ctx, chainID) + if rf, ok := ret.Get(0).(func(context.Context, ADDR, CHAIN_ID) ([]txmgrtypes.TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], error)); ok { + return rf(ctx, address, chainID) } - if rf, ok := ret.Get(0).(func(context.Context, CHAIN_ID) []*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]); ok { - r0 = rf(ctx, chainID) + if rf, ok := ret.Get(0).(func(context.Context, ADDR, CHAIN_ID) []txmgrtypes.TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]); ok { + r0 = rf(ctx, address, chainID) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).([]*txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) + r0 = ret.Get(0).([]txmgrtypes.TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) } } - if rf, ok := ret.Get(1).(func(context.Context, CHAIN_ID) error); ok { - r1 = rf(ctx, chainID) + if rf, ok := ret.Get(1).(func(context.Context, ADDR, CHAIN_ID) error); ok { + r1 = rf(ctx, address, chainID) } else { r1 = ret.Error(1) } diff --git a/common/txmgr/types/tx_store.go b/common/txmgr/types/tx_store.go index 43d41cb4d31..bca2d1e3647 100644 --- a/common/txmgr/types/tx_store.go +++ b/common/txmgr/types/tx_store.go @@ -85,7 +85,7 @@ type TransactionStore[ FindEarliestUnconfirmedTxAttemptBlock(ctx context.Context, chainID CHAIN_ID) (null.Int, error) GetTxInProgress(ctx context.Context, fromAddress ADDR) (etx *Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], err error) GetInProgressTxAttempts(ctx context.Context, address ADDR, chainID CHAIN_ID) (attempts []TxAttempt[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], err error) - GetNonFatalTransactions(ctx context.Context, chainID CHAIN_ID) (txs []*Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], err error) + GetAbandonedTransactionsByBatch(ctx context.Context, chainID CHAIN_ID, enabledAddrs []ADDR, offset, limit uint) (txs []*Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], err error) GetTxByID(ctx context.Context, id int64) (tx *Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], err error) HasInProgressTransaction(ctx context.Context, account ADDR, chainID CHAIN_ID) (exists bool, err error) LoadTxAttempts(ctx context.Context, etx *Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) error diff --git a/core/chains/evm/txmgr/evm_tx_store.go b/core/chains/evm/txmgr/evm_tx_store.go index 7a05f32c9e4..c8e664e8cfe 100644 --- a/core/chains/evm/txmgr/evm_tx_store.go +++ b/core/chains/evm/txmgr/evm_tx_store.go @@ -1292,26 +1292,28 @@ func (o *evmTxStore) SaveInProgressAttempt(ctx context.Context, attempt *TxAttem return nil } -func (o *evmTxStore) GetNonFatalTransactions(ctx context.Context, chainID *big.Int) (txes []*Tx, err error) { +func (o *evmTxStore) GetAbandonedTransactionsByBatch(ctx context.Context, chainID *big.Int, enabledAddrs []common.Address, offset, limit uint) (txes []*Tx, err error) { var cancel context.CancelFunc ctx, cancel = o.mergeContexts(ctx) defer cancel() - err = o.Transaction(ctx, true, func(orm *evmTxStore) error { - stmt := `SELECT * FROM evm.txes WHERE state <> 'fatal_error' AND evm_chain_id = $1` - var dbEtxs []DbEthTx - if err = orm.q.SelectContext(ctx, &dbEtxs, stmt, chainID.String()); err != nil { - return fmt.Errorf("failed to load evm.txes: %w", err) - } - txes = make([]*Tx, len(dbEtxs)) - dbEthTxsToEvmEthTxPtrs(dbEtxs, txes) - err = o.LoadTxesAttempts(ctx, txes) - if err != nil { - return fmt.Errorf("failed to load evm.txes: %w", err) - } - return nil - }) - return txes, nil + var enabledAddrsBytea [][]byte + for _, addr := range enabledAddrs { + enabledAddrsBytea = append(enabledAddrsBytea, addr[:]) + } + + // TODO: include confirmed txes https://smartcontract-it.atlassian.net/browse/BCI-2920 + query := `SELECT * FROM evm.txes WHERE state <> 'fatal_error' AND state <> 'confirmed' AND evm_chain_id = $1 + AND from_address <> ALL($2) ORDER BY nonce ASC OFFSET $3 LIMIT $4` + + var dbEtxs []DbEthTx + if err = o.q.SelectContext(ctx, &dbEtxs, query, chainID.String(), enabledAddrsBytea, offset, limit); err != nil { + return nil, fmt.Errorf("failed to load evm.txes: %w", err) + } + txes = make([]*Tx, len(dbEtxs)) + dbEthTxsToEvmEthTxPtrs(dbEtxs, txes) + + return txes, err } func (o *evmTxStore) GetTxByID(ctx context.Context, id int64) (txe *Tx, err error) { diff --git a/core/chains/evm/txmgr/evm_tx_store_test.go b/core/chains/evm/txmgr/evm_tx_store_test.go index 5d3fdcfafd8..ff5c7ec4abc 100644 --- a/core/chains/evm/txmgr/evm_tx_store_test.go +++ b/core/chains/evm/txmgr/evm_tx_store_test.go @@ -7,6 +7,8 @@ import ( "testing" "time" + "github.com/ethereum/go-ethereum/common" + commonconfig "github.com/smartcontractkit/chainlink-common/pkg/config" "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" @@ -1470,7 +1472,7 @@ func TestORM_GetTxInProgress(t *testing.T) { }) } -func TestORM_GetNonFatalTransactions(t *testing.T) { +func TestORM_GetAbandonedTransactionsByBatch(t *testing.T) { t.Parallel() db := pgtest.NewSqlxDB(t) @@ -1479,9 +1481,19 @@ func TestORM_GetNonFatalTransactions(t *testing.T) { ethKeyStore := cltest.NewKeyStore(t, db, cfg.Database()).Eth() ethClient := evmtest.NewEthClientMockWithDefaultChain(t) _, fromAddress := cltest.MustInsertRandomKeyReturningState(t, ethKeyStore) + _, enabled := cltest.MustInsertRandomKeyReturningState(t, ethKeyStore) + enabledAddrs := []common.Address{enabled} + + t.Run("get 0 abandoned transactions", func(t *testing.T) { + txes, err := txStore.GetAbandonedTransactionsByBatch(testutils.Context(t), ethClient.ConfiguredChainID(), enabledAddrs, 0, 10) + require.NoError(t, err) + require.Empty(t, txes) + }) - t.Run("gets 0 non finalized eth transaction", func(t *testing.T) { - txes, err := txStore.GetNonFatalTransactions(testutils.Context(t), ethClient.ConfiguredChainID()) + t.Run("do not return enabled addresses", func(t *testing.T) { + _ = mustInsertInProgressEthTxWithAttempt(t, txStore, 123, enabled) + _ = mustCreateUnstartedGeneratedTx(t, txStore, enabled, ethClient.ConfiguredChainID()) + txes, err := txStore.GetAbandonedTransactionsByBatch(testutils.Context(t), ethClient.ConfiguredChainID(), enabledAddrs, 0, 10) require.NoError(t, err) require.Empty(t, txes) }) @@ -1490,13 +1502,32 @@ func TestORM_GetNonFatalTransactions(t *testing.T) { inProgressTx := mustInsertInProgressEthTxWithAttempt(t, txStore, 123, fromAddress) unstartedTx := mustCreateUnstartedGeneratedTx(t, txStore, fromAddress, ethClient.ConfiguredChainID()) - txes, err := txStore.GetNonFatalTransactions(testutils.Context(t), ethClient.ConfiguredChainID()) + txes, err := txStore.GetAbandonedTransactionsByBatch(testutils.Context(t), ethClient.ConfiguredChainID(), enabledAddrs, 0, 10) require.NoError(t, err) + require.Len(t, txes, 2) for _, tx := range txes { require.True(t, tx.ID == inProgressTx.ID || tx.ID == unstartedTx.ID) } }) + + t.Run("get batches of transactions", func(t *testing.T) { + var batchSize uint = 10 + numTxes := 55 + for i := 0; i < numTxes; i++ { + _ = mustCreateUnstartedGeneratedTx(t, txStore, fromAddress, ethClient.ConfiguredChainID()) + } + + allTxes := make([]*txmgr.Tx, 0) + err := sqlutil.Batch(func(offset, limit uint) (count uint, err error) { + batchTxes, err := txStore.GetAbandonedTransactionsByBatch(testutils.Context(t), ethClient.ConfiguredChainID(), enabledAddrs, offset, limit) + require.NoError(t, err) + allTxes = append(allTxes, batchTxes...) + return uint(len(batchTxes)), nil + }, batchSize) + require.NoError(t, err) + require.Len(t, allTxes, numTxes+2) + }) } func TestORM_GetTxByID(t *testing.T) { diff --git a/core/chains/evm/txmgr/mocks/evm_tx_store.go b/core/chains/evm/txmgr/mocks/evm_tx_store.go index b6806f34d76..a05f2a22c60 100644 --- a/core/chains/evm/txmgr/mocks/evm_tx_store.go +++ b/core/chains/evm/txmgr/mocks/evm_tx_store.go @@ -821,29 +821,29 @@ func (_m *EvmTxStore) FindTxsRequiringResubmissionDueToInsufficientFunds(ctx con return r0, r1 } -// GetInProgressTxAttempts provides a mock function with given fields: ctx, address, chainID -func (_m *EvmTxStore) GetInProgressTxAttempts(ctx context.Context, address common.Address, chainID *big.Int) ([]types.TxAttempt[*big.Int, common.Address, common.Hash, common.Hash, evmtypes.Nonce, gas.EvmFee], error) { - ret := _m.Called(ctx, address, chainID) +// GetAbandonedTransactionsByBatch provides a mock function with given fields: ctx, chainID, enabledAddrs, offset, limit +func (_m *EvmTxStore) GetAbandonedTransactionsByBatch(ctx context.Context, chainID *big.Int, enabledAddrs []common.Address, offset uint, limit uint) ([]*types.Tx[*big.Int, common.Address, common.Hash, common.Hash, evmtypes.Nonce, gas.EvmFee], error) { + ret := _m.Called(ctx, chainID, enabledAddrs, offset, limit) if len(ret) == 0 { - panic("no return value specified for GetInProgressTxAttempts") + panic("no return value specified for GetAbandonedTransactionsByBatch") } - var r0 []types.TxAttempt[*big.Int, common.Address, common.Hash, common.Hash, evmtypes.Nonce, gas.EvmFee] + var r0 []*types.Tx[*big.Int, common.Address, common.Hash, common.Hash, evmtypes.Nonce, gas.EvmFee] var r1 error - if rf, ok := ret.Get(0).(func(context.Context, common.Address, *big.Int) ([]types.TxAttempt[*big.Int, common.Address, common.Hash, common.Hash, evmtypes.Nonce, gas.EvmFee], error)); ok { - return rf(ctx, address, chainID) + if rf, ok := ret.Get(0).(func(context.Context, *big.Int, []common.Address, uint, uint) ([]*types.Tx[*big.Int, common.Address, common.Hash, common.Hash, evmtypes.Nonce, gas.EvmFee], error)); ok { + return rf(ctx, chainID, enabledAddrs, offset, limit) } - if rf, ok := ret.Get(0).(func(context.Context, common.Address, *big.Int) []types.TxAttempt[*big.Int, common.Address, common.Hash, common.Hash, evmtypes.Nonce, gas.EvmFee]); ok { - r0 = rf(ctx, address, chainID) + if rf, ok := ret.Get(0).(func(context.Context, *big.Int, []common.Address, uint, uint) []*types.Tx[*big.Int, common.Address, common.Hash, common.Hash, evmtypes.Nonce, gas.EvmFee]); ok { + r0 = rf(ctx, chainID, enabledAddrs, offset, limit) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).([]types.TxAttempt[*big.Int, common.Address, common.Hash, common.Hash, evmtypes.Nonce, gas.EvmFee]) + r0 = ret.Get(0).([]*types.Tx[*big.Int, common.Address, common.Hash, common.Hash, evmtypes.Nonce, gas.EvmFee]) } } - if rf, ok := ret.Get(1).(func(context.Context, common.Address, *big.Int) error); ok { - r1 = rf(ctx, address, chainID) + if rf, ok := ret.Get(1).(func(context.Context, *big.Int, []common.Address, uint, uint) error); ok { + r1 = rf(ctx, chainID, enabledAddrs, offset, limit) } else { r1 = ret.Error(1) } @@ -851,29 +851,29 @@ func (_m *EvmTxStore) GetInProgressTxAttempts(ctx context.Context, address commo return r0, r1 } -// GetNonFatalTransactions provides a mock function with given fields: ctx, chainID -func (_m *EvmTxStore) GetNonFatalTransactions(ctx context.Context, chainID *big.Int) ([]*types.Tx[*big.Int, common.Address, common.Hash, common.Hash, evmtypes.Nonce, gas.EvmFee], error) { - ret := _m.Called(ctx, chainID) +// GetInProgressTxAttempts provides a mock function with given fields: ctx, address, chainID +func (_m *EvmTxStore) GetInProgressTxAttempts(ctx context.Context, address common.Address, chainID *big.Int) ([]types.TxAttempt[*big.Int, common.Address, common.Hash, common.Hash, evmtypes.Nonce, gas.EvmFee], error) { + ret := _m.Called(ctx, address, chainID) if len(ret) == 0 { - panic("no return value specified for GetNonFatalTransactions") + panic("no return value specified for GetInProgressTxAttempts") } - var r0 []*types.Tx[*big.Int, common.Address, common.Hash, common.Hash, evmtypes.Nonce, gas.EvmFee] + var r0 []types.TxAttempt[*big.Int, common.Address, common.Hash, common.Hash, evmtypes.Nonce, gas.EvmFee] var r1 error - if rf, ok := ret.Get(0).(func(context.Context, *big.Int) ([]*types.Tx[*big.Int, common.Address, common.Hash, common.Hash, evmtypes.Nonce, gas.EvmFee], error)); ok { - return rf(ctx, chainID) + if rf, ok := ret.Get(0).(func(context.Context, common.Address, *big.Int) ([]types.TxAttempt[*big.Int, common.Address, common.Hash, common.Hash, evmtypes.Nonce, gas.EvmFee], error)); ok { + return rf(ctx, address, chainID) } - if rf, ok := ret.Get(0).(func(context.Context, *big.Int) []*types.Tx[*big.Int, common.Address, common.Hash, common.Hash, evmtypes.Nonce, gas.EvmFee]); ok { - r0 = rf(ctx, chainID) + if rf, ok := ret.Get(0).(func(context.Context, common.Address, *big.Int) []types.TxAttempt[*big.Int, common.Address, common.Hash, common.Hash, evmtypes.Nonce, gas.EvmFee]); ok { + r0 = rf(ctx, address, chainID) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).([]*types.Tx[*big.Int, common.Address, common.Hash, common.Hash, evmtypes.Nonce, gas.EvmFee]) + r0 = ret.Get(0).([]types.TxAttempt[*big.Int, common.Address, common.Hash, common.Hash, evmtypes.Nonce, gas.EvmFee]) } } - if rf, ok := ret.Get(1).(func(context.Context, *big.Int) error); ok { - r1 = rf(ctx, chainID) + if rf, ok := ret.Get(1).(func(context.Context, common.Address, *big.Int) error); ok { + r1 = rf(ctx, address, chainID) } else { r1 = ret.Error(1) } diff --git a/core/chains/evm/txmgr/tracker_test.go b/core/chains/evm/txmgr/tracker_test.go index e95c005dc77..eefd89c69eb 100644 --- a/core/chains/evm/txmgr/tracker_test.go +++ b/core/chains/evm/txmgr/tracker_test.go @@ -1,7 +1,6 @@ package txmgr_test import ( - "context" "math/big" "testing" "time" @@ -10,6 +9,7 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/chains/evm/txmgr" ubig "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big" "github.com/smartcontractkit/chainlink/v2/core/internal/cltest" + "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/evmtest" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils/pgtest" "github.com/smartcontractkit/chainlink/v2/core/logger" @@ -44,25 +44,23 @@ func containsID(txes []*txmgr.Tx, id int64) bool { } func TestEvmTracker_Initialization(t *testing.T) { - t.Skip("BCI-2638 tracker disabled") t.Parallel() tracker, _, _, _ := newTestEvmTrackerSetup(t) + ctx := testutils.Context(t) - err := tracker.Start(context.Background()) - require.NoError(t, err) + require.NoError(t, tracker.Start(ctx)) require.True(t, tracker.IsStarted()) t.Run("stop tracker", func(t *testing.T) { - err := tracker.Close() - require.NoError(t, err) + require.NoError(t, tracker.Close()) require.False(t, tracker.IsStarted()) }) } func TestEvmTracker_AddressTracking(t *testing.T) { - t.Skip("BCI-2638 tracker disabled") t.Parallel() + ctx := testutils.Context(t) t.Run("track abandoned addresses", func(t *testing.T) { ethClient := evmtest.NewEthClientMockWithDefaultChain(t) @@ -76,33 +74,37 @@ func TestEvmTracker_AddressTracking(t *testing.T) { _ = mustInsertConfirmedEthTxWithReceipt(t, txStore, confirmedAddr, 123, 1) _ = mustCreateUnstartedTx(t, txStore, unstartedAddr, cltest.MustGenerateRandomKey(t).Address, []byte{}, 0, big.Int{}, ethClient.ConfiguredChainID()) - err := tracker.Start(context.Background()) + err := tracker.Start(ctx) require.NoError(t, err) defer func(tracker *txmgr.Tracker) { err = tracker.Close() require.NoError(t, err) }(tracker) + time.Sleep(waitTime) addrs := tracker.GetAbandonedAddresses() require.NotContains(t, addrs, inProgressAddr) require.NotContains(t, addrs, unstartedAddr) - require.Contains(t, addrs, confirmedAddr) require.Contains(t, addrs, unconfirmedAddr) }) + /* TODO: finalized tx state https://smartcontract-it.atlassian.net/browse/BCI-2920 t.Run("stop tracking finalized tx", func(t *testing.T) { - t.Skip("BCI-2638 tracker disabled") tracker, txStore, _, _ := newTestEvmTrackerSetup(t) confirmedAddr := cltest.MustGenerateRandomKey(t).Address _ = mustInsertConfirmedEthTxWithReceipt(t, txStore, confirmedAddr, 123, 1) - err := tracker.Start(context.Background()) + err := tracker.Start(ctx) require.NoError(t, err) defer func(tracker *txmgr.Tracker) { err = tracker.Close() require.NoError(t, err) }(tracker) + // deliver block before minConfirmations + tracker.XXXDeliverBlock(1) + time.Sleep(waitTime) + addrs := tracker.GetAbandonedAddresses() require.Contains(t, addrs, confirmedAddr) @@ -113,26 +115,12 @@ func TestEvmTracker_AddressTracking(t *testing.T) { addrs = tracker.GetAbandonedAddresses() require.NotContains(t, addrs, confirmedAddr) }) + */ } func TestEvmTracker_ExceedingTTL(t *testing.T) { - t.Skip("BCI-2638 tracker disabled") t.Parallel() - - t.Run("confirmed but unfinalized transaction still tracked", func(t *testing.T) { - tracker, txStore, _, _ := newTestEvmTrackerSetup(t) - addr1 := cltest.MustGenerateRandomKey(t).Address - _ = mustInsertConfirmedEthTxWithReceipt(t, txStore, addr1, 123, 1) - - err := tracker.Start(context.Background()) - require.NoError(t, err) - defer func(tracker *txmgr.Tracker) { - err = tracker.Close() - require.NoError(t, err) - }(tracker) - - require.Contains(t, tracker.GetAbandonedAddresses(), addr1) - }) + ctx := testutils.Context(t) t.Run("exceeding ttl", func(t *testing.T) { tracker, txStore, _, _ := newTestEvmTrackerSetup(t) @@ -142,17 +130,17 @@ func TestEvmTracker_ExceedingTTL(t *testing.T) { tx2 := cltest.MustInsertUnconfirmedEthTx(t, txStore, 123, addr2) tracker.XXXTestSetTTL(time.Nanosecond) - err := tracker.Start(context.Background()) + err := tracker.Start(ctx) require.NoError(t, err) defer func(tracker *txmgr.Tracker) { err = tracker.Close() require.NoError(t, err) }(tracker) - time.Sleep(waitTime) + time.Sleep(100 * waitTime) require.NotContains(t, tracker.GetAbandonedAddresses(), addr1, addr2) - fatalTxes, err := txStore.GetFatalTransactions(context.Background()) + fatalTxes, err := txStore.GetFatalTransactions(ctx) require.NoError(t, err) require.True(t, containsID(fatalTxes, tx1.ID)) require.True(t, containsID(fatalTxes, tx2.ID)) From ae1f53f993c915e9ceed79a37479f9e5977099c3 Mon Sep 17 00:00:00 2001 From: Finley Decker Date: Wed, 17 Apr 2024 21:01:05 -1000 Subject: [PATCH 16/19] fix install (#12872) --- operator_ui/check.sh | 2 +- operator_ui/install.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/operator_ui/check.sh b/operator_ui/check.sh index 9e738218088..e4e12209b05 100755 --- a/operator_ui/check.sh +++ b/operator_ui/check.sh @@ -6,7 +6,7 @@ set -e # jq ^1.6 https://stedolan.github.io/jq/ repo=smartcontractkit/operator-ui -gitRoot=$(git rev-parse --show-toplevel) +gitRoot="$(dirname -- "$0")/../" cd "$gitRoot/operator_ui" tag_file=TAG diff --git a/operator_ui/install.sh b/operator_ui/install.sh index 0de72d51f4e..f86c9a2f352 100755 --- a/operator_ui/install.sh +++ b/operator_ui/install.sh @@ -4,7 +4,7 @@ set -e owner=smartcontractkit repo=operator-ui fullRepo=${owner}/${repo} -gitRoot=$(git rev-parse --show-toplevel || pwd) +gitRoot="$(dirname -- "$0")/../" cd "$gitRoot/operator_ui" unpack_dir="$gitRoot/core/web/assets" tag=$(cat TAG) From 057ef81edccbd34526f5b712f837a17082e94bc1 Mon Sep 17 00:00:00 2001 From: Lee Yik Jiun Date: Thu, 18 Apr 2024 18:16:17 +0800 Subject: [PATCH 17/19] Move vrfv2plus wrapper events from interface to implementation (#12875) --- .../src/v0.8/vrf/dev/VRFV2PlusWrapper.sol | 20 +++++++++++++++++++ .../vrf/dev/interfaces/IVRFV2PlusWrapper.sol | 20 ------------------- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/contracts/src/v0.8/vrf/dev/VRFV2PlusWrapper.sol b/contracts/src/v0.8/vrf/dev/VRFV2PlusWrapper.sol index 77c225cdb02..4a806db5515 100644 --- a/contracts/src/v0.8/vrf/dev/VRFV2PlusWrapper.sol +++ b/contracts/src/v0.8/vrf/dev/VRFV2PlusWrapper.sol @@ -32,6 +32,26 @@ contract VRFV2PlusWrapper is ConfirmedOwner, TypeAndVersionInterface, VRFConsume LinkTokenInterface internal immutable i_link; AggregatorV3Interface internal immutable i_link_native_feed; + event FulfillmentTxSizeSet(uint32 size); + event ConfigSet( + uint32 wrapperGasOverhead, + uint32 coordinatorGasOverhead, + uint16 coordinatorGasOverheadPerWord, + uint8 coordinatorNativePremiumPercentage, + uint8 coordinatorLinkPremiumPercentage, + bytes32 keyHash, + uint8 maxNumWords, + uint32 stalenessSeconds, + int256 fallbackWeiPerUnitLink, + uint32 fulfillmentFlatFeeNativePPM, + uint32 fulfillmentFlatFeeLinkDiscountPPM + ); + event FallbackWeiPerUnitLinkUsed(uint256 requestId, int256 fallbackWeiPerUnitLink); + event Withdrawn(address indexed to, uint256 amount); + event NativeWithdrawn(address indexed to, uint256 amount); + event Enabled(); + event Disabled(); + error LinkAlreadySet(); error LinkDiscountTooHigh(uint32 flatFeeLinkDiscountPPM, uint32 flatFeeNativePPM); error InvalidPremiumPercentage(uint8 premiumPercentage, uint8 max); diff --git a/contracts/src/v0.8/vrf/dev/interfaces/IVRFV2PlusWrapper.sol b/contracts/src/v0.8/vrf/dev/interfaces/IVRFV2PlusWrapper.sol index 917c59433ef..85b0c47659d 100644 --- a/contracts/src/v0.8/vrf/dev/interfaces/IVRFV2PlusWrapper.sol +++ b/contracts/src/v0.8/vrf/dev/interfaces/IVRFV2PlusWrapper.sol @@ -2,26 +2,6 @@ pragma solidity ^0.8.0; interface IVRFV2PlusWrapper { - event FulfillmentTxSizeSet(uint32 size); - event ConfigSet( - uint32 wrapperGasOverhead, - uint32 coordinatorGasOverhead, - uint16 coordinatorGasOverheadPerWord, - uint8 coordinatorNativePremiumPercentage, - uint8 coordinatorLinkPremiumPercentage, - bytes32 keyHash, - uint8 maxNumWords, - uint32 stalenessSeconds, - int256 fallbackWeiPerUnitLink, - uint32 fulfillmentFlatFeeNativePPM, - uint32 fulfillmentFlatFeeLinkDiscountPPM - ); - event FallbackWeiPerUnitLinkUsed(uint256 requestId, int256 fallbackWeiPerUnitLink); - event Withdrawn(address indexed to, uint256 amount); - event NativeWithdrawn(address indexed to, uint256 amount); - event Enabled(); - event Disabled(); - /** * @return the request ID of the most recent VRF V2 request made by this wrapper. This should only * be relied option within the same transaction that the request was made. From 6991af26d9fa0e048b72a05f4f9c13f2306c0328 Mon Sep 17 00:00:00 2001 From: Silas Lenihan <32529249+silaslenihan@users.noreply.github.com> Date: Thu, 18 Apr 2024 11:23:55 -0400 Subject: [PATCH 18/19] Implement the gas estimator components refactor (#12650) * Applied LimitMultiplier to chainSpecificGasLimit in all gas estimators * Intial refactor * Revert "Applied LimitMultiplier to chainSpecificGasLimit in all gas estimators" This reverts commit e0c581c9e58a93ed64714d0b445eeffe120e4c2a. * Refactored L1GasOracle * Fixed failing tests * fixed linting issues * Updated ethClient struct to be specific to l1Oracle * Removed changes to broadcaster error handling * fixed rollup tests * Fixed remaining issues with PR * updated changesets * fixed linting issues * minor fixes * Deduplicated L1Oracle creation * Changed L1Oracle to return pointer * Update core/chains/evm/gas/rollups/arbitrum_l1_oracle.go Co-authored-by: amit-momin <108959691+amit-momin@users.noreply.github.com> * Added l1Oracle as a parameter to pass into gas estimators * removed casting for direct call * removed unused mocks --------- Co-authored-by: amit-momin <108959691+amit-momin@users.noreply.github.com> --- .changeset/flat-guests-marry.md | 6 + core/chains/evm/gas/arbitrum_estimator.go | 74 +-- .../chains/evm/gas/arbitrum_estimator_test.go | 119 ++--- .../chains/evm/gas/block_history_estimator.go | 14 +- .../evm/gas/block_history_estimator_test.go | 168 +++++-- core/chains/evm/gas/cmd/arbgas/main.go | 85 ---- core/chains/evm/gas/fixed_price_estimator.go | 10 +- .../evm/gas/fixed_price_estimator_test.go | 21 +- core/chains/evm/gas/mocks/eth_client.go | 61 --- core/chains/evm/gas/mocks/evm_estimator.go | 22 + .../evm/gas/mocks/fee_estimator_client.go | 154 +++++++ core/chains/evm/gas/mocks/rpc_client.go | 49 -- core/chains/evm/gas/models.go | 81 ++-- core/chains/evm/gas/models_test.go | 56 ++- .../evm/gas/rollups/arbitrum_l1_oracle.go | 300 ++++++++++++ core/chains/evm/gas/rollups/l1_oracle.go | 346 +------------- core/chains/evm/gas/rollups/l1_oracle_test.go | 92 +++- .../evm/gas/rollups/mocks/da_price_reader.go | 59 --- .../{eth_client.go => l1_oracle_client.go} | 16 +- core/chains/evm/gas/rollups/models.go | 22 - core/chains/evm/gas/rollups/op_l1_oracle.go | 431 ++++++++++++++++++ ...ce_reader_test.go => op_l1_oracle_test.go} | 18 +- .../chains/evm/gas/rollups/op_price_reader.go | 228 --------- .../evm/gas/suggested_price_estimator.go | 15 +- .../evm/gas/suggested_price_estimator_test.go | 117 +++-- core/chains/evm/txmgr/broadcaster_test.go | 18 +- core/chains/evm/txmgr/confirmer_test.go | 12 +- 27 files changed, 1443 insertions(+), 1151 deletions(-) create mode 100644 .changeset/flat-guests-marry.md delete mode 100644 core/chains/evm/gas/cmd/arbgas/main.go delete mode 100644 core/chains/evm/gas/mocks/eth_client.go create mode 100644 core/chains/evm/gas/mocks/fee_estimator_client.go delete mode 100644 core/chains/evm/gas/mocks/rpc_client.go create mode 100644 core/chains/evm/gas/rollups/arbitrum_l1_oracle.go delete mode 100644 core/chains/evm/gas/rollups/mocks/da_price_reader.go rename core/chains/evm/gas/rollups/mocks/{eth_client.go => l1_oracle_client.go} (72%) delete mode 100644 core/chains/evm/gas/rollups/models.go create mode 100644 core/chains/evm/gas/rollups/op_l1_oracle.go rename core/chains/evm/gas/rollups/{op_price_reader_test.go => op_l1_oracle_test.go} (90%) delete mode 100644 core/chains/evm/gas/rollups/op_price_reader.go diff --git a/.changeset/flat-guests-marry.md b/.changeset/flat-guests-marry.md new file mode 100644 index 00000000000..c1eb6549a96 --- /dev/null +++ b/.changeset/flat-guests-marry.md @@ -0,0 +1,6 @@ +--- +"chainlink": minor +--- + +#internal Gas Estimator L1Oracles to be chain specific +#removed cmd/arbgas diff --git a/core/chains/evm/gas/arbitrum_estimator.go b/core/chains/evm/gas/arbitrum_estimator.go index 40366c5b998..0cd4bbcdd0b 100644 --- a/core/chains/evm/gas/arbitrum_estimator.go +++ b/core/chains/evm/gas/arbitrum_estimator.go @@ -3,14 +3,10 @@ package gas import ( "context" "fmt" - "math" - "math/big" "slices" "sync" "time" - "github.com/ethereum/go-ethereum" - "github.com/ethereum/go-ethereum/common" pkgerrors "github.com/pkg/errors" "github.com/smartcontractkit/chainlink-common/pkg/logger" @@ -19,7 +15,7 @@ import ( feetypes "github.com/smartcontractkit/chainlink/v2/common/fee/types" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/assets" - evmclient "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" + "github.com/smartcontractkit/chainlink/v2/core/chains/evm/gas/rollups" ) type ArbConfig interface { @@ -28,11 +24,6 @@ type ArbConfig interface { BumpMin() *assets.Wei } -//go:generate mockery --quiet --name ethClient --output ./mocks/ --case=underscore --structname ETHClient -type ethClient interface { - CallContract(ctx context.Context, msg ethereum.CallMsg, blockNumber *big.Int) ([]byte, error) -} - // arbitrumEstimator is an Estimator which extends SuggestedPriceEstimator to use getPricesInArbGas() for gas limit estimation. type arbitrumEstimator struct { services.StateMachine @@ -40,7 +31,6 @@ type arbitrumEstimator struct { EvmEstimator // *SuggestedPriceEstimator - client ethClient pollPeriod time.Duration logger logger.Logger @@ -52,20 +42,23 @@ type arbitrumEstimator struct { chInitialised chan struct{} chStop services.StopChan chDone chan struct{} + + l1Oracle rollups.ArbL1GasOracle } -func NewArbitrumEstimator(lggr logger.Logger, cfg ArbConfig, rpcClient rpcClient, ethClient ethClient) EvmEstimator { +func NewArbitrumEstimator(lggr logger.Logger, cfg ArbConfig, ethClient feeEstimatorClient, l1Oracle rollups.ArbL1GasOracle) EvmEstimator { lggr = logger.Named(lggr, "ArbitrumEstimator") + return &arbitrumEstimator{ cfg: cfg, - EvmEstimator: NewSuggestedPriceEstimator(lggr, rpcClient, cfg), - client: ethClient, + EvmEstimator: NewSuggestedPriceEstimator(lggr, ethClient, cfg, l1Oracle), pollPeriod: 10 * time.Second, logger: lggr, chForceRefetch: make(chan (chan struct{})), chInitialised: make(chan struct{}), chStop: make(chan struct{}), chDone: make(chan struct{}), + l1Oracle: l1Oracle, } } @@ -196,7 +189,7 @@ func (a *arbitrumEstimator) run() { func (a *arbitrumEstimator) refreshPricesInArbGas() (t *time.Timer) { t = time.NewTimer(utils.WithJitter(a.pollPeriod)) - perL2Tx, perL1CalldataUnit, err := a.callGetPricesInArbGas() + perL2Tx, perL1CalldataUnit, err := a.l1Oracle.GetPricesInArbGas() if err != nil { a.logger.Warnw("Failed to refresh prices", "err", err) return @@ -210,54 +203,3 @@ func (a *arbitrumEstimator) refreshPricesInArbGas() (t *time.Timer) { a.getPricesInArbGasMu.Unlock() return } - -const ( - // ArbGasInfoAddress is the address of the "Precompiled contract that exists in every Arbitrum chain." - // https://github.com/OffchainLabs/nitro/blob/f7645453cfc77bf3e3644ea1ac031eff629df325/contracts/src/precompiles/ArbGasInfo.sol - ArbGasInfoAddress = "0x000000000000000000000000000000000000006C" - // ArbGasInfo_getPricesInArbGas is the a hex encoded call to: - // `function getPricesInArbGas() external view returns (uint256, uint256, uint256);` - ArbGasInfo_getPricesInArbGas = "02199f34" -) - -// callGetPricesInArbGas calls ArbGasInfo.getPricesInArbGas() on the precompile contract ArbGasInfoAddress. -// -// @return (per L2 tx, per L1 calldata unit, per storage allocation) -// function getPricesInArbGas() external view returns (uint256, uint256, uint256); -// -// https://github.com/OffchainLabs/nitro/blob/f7645453cfc77bf3e3644ea1ac031eff629df325/contracts/src/precompiles/ArbGasInfo.sol#L69 -func (a *arbitrumEstimator) callGetPricesInArbGas() (perL2Tx uint32, perL1CalldataUnit uint32, err error) { - ctx, cancel := a.chStop.CtxCancel(evmclient.ContextWithDefaultTimeout()) - defer cancel() - - precompile := common.HexToAddress(ArbGasInfoAddress) - b, err := a.client.CallContract(ctx, ethereum.CallMsg{ - To: &precompile, - Data: common.Hex2Bytes(ArbGasInfo_getPricesInArbGas), - }, big.NewInt(-1)) - if err != nil { - return 0, 0, err - } - - if len(b) != 3*32 { // returns (uint256, uint256, uint256); - err = fmt.Errorf("return data length (%d) different than expected (%d)", len(b), 3*32) - return - } - bPerL2Tx := new(big.Int).SetBytes(b[:32]) - bPerL1CalldataUnit := new(big.Int).SetBytes(b[32:64]) - // ignore perStorageAllocation - if !bPerL2Tx.IsUint64() || !bPerL1CalldataUnit.IsUint64() { - err = fmt.Errorf("returned integers are not uint64 (%s, %s)", bPerL2Tx.String(), bPerL1CalldataUnit.String()) - return - } - - perL2TxU64 := bPerL2Tx.Uint64() - perL1CalldataUnitU64 := bPerL1CalldataUnit.Uint64() - if perL2TxU64 > math.MaxUint32 || perL1CalldataUnitU64 > math.MaxUint32 { - err = fmt.Errorf("returned integers are not uint32 (%d, %d)", perL2TxU64, perL1CalldataUnitU64) - return - } - perL2Tx = uint32(perL2TxU64) - perL1CalldataUnit = uint32(perL1CalldataUnitU64) - return -} diff --git a/core/chains/evm/gas/arbitrum_estimator_test.go b/core/chains/evm/gas/arbitrum_estimator_test.go index 3c46b466e87..54d7fc333e3 100644 --- a/core/chains/evm/gas/arbitrum_estimator_test.go +++ b/core/chains/evm/gas/arbitrum_estimator_test.go @@ -19,6 +19,7 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/chains/evm/assets" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/gas" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/gas/mocks" + "github.com/smartcontractkit/chainlink/v2/core/chains/evm/gas/rollups" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" ) @@ -52,9 +53,10 @@ func TestArbitrumEstimator(t *testing.T) { var bumpMin = assets.NewWei(big.NewInt(1)) t.Run("calling GetLegacyGas on unstarted estimator returns error", func(t *testing.T) { - rpcClient := mocks.NewRPCClient(t) - ethClient := mocks.NewETHClient(t) - o := gas.NewArbitrumEstimator(logger.Test(t), &arbConfig{}, rpcClient, ethClient) + feeEstimatorClient := mocks.NewFeeEstimatorClient(t) + l1Oracle := rollups.NewArbitrumL1GasOracle(logger.Test(t), feeEstimatorClient) + + o := gas.NewArbitrumEstimator(logger.Test(t), &arbConfig{}, feeEstimatorClient, l1Oracle) _, _, err := o.GetLegacyGas(testutils.Context(t), calldata, gasLimit, maxGasPrice) assert.EqualError(t, err, "estimator is not started") }) @@ -64,21 +66,22 @@ func TestArbitrumEstimator(t *testing.T) { zeros.Write(common.BigToHash(big.NewInt(0)).Bytes()) zeros.Write(common.BigToHash(big.NewInt(123455)).Bytes()) t.Run("calling GetLegacyGas on started estimator returns estimates", func(t *testing.T) { - rpcClient := mocks.NewRPCClient(t) - ethClient := mocks.NewETHClient(t) - rpcClient.On("CallContext", mock.Anything, mock.Anything, "eth_gasPrice").Return(nil).Run(func(args mock.Arguments) { + feeEstimatorClient := mocks.NewFeeEstimatorClient(t) + l1Oracle := rollups.NewArbitrumL1GasOracle(logger.Test(t), feeEstimatorClient) + + feeEstimatorClient.On("CallContext", mock.Anything, mock.Anything, "eth_gasPrice").Return(nil).Run(func(args mock.Arguments) { res := args.Get(1).(*hexutil.Big) (*big.Int)(res).SetInt64(42) }) - ethClient.On("CallContract", mock.Anything, mock.IsType(ethereum.CallMsg{}), mock.IsType(&big.Int{})).Run(func(args mock.Arguments) { + feeEstimatorClient.On("CallContract", mock.Anything, mock.IsType(ethereum.CallMsg{}), mock.IsType(&big.Int{})).Run(func(args mock.Arguments) { callMsg := args.Get(1).(ethereum.CallMsg) blockNumber := args.Get(2).(*big.Int) - assert.Equal(t, gas.ArbGasInfoAddress, callMsg.To.String()) - assert.Equal(t, gas.ArbGasInfo_getPricesInArbGas, fmt.Sprintf("%x", callMsg.Data)) + assert.Equal(t, rollups.ArbGasInfoAddress, callMsg.To.String()) + assert.Equal(t, rollups.ArbGasInfo_getPricesInArbGas, fmt.Sprintf("%x", callMsg.Data)) assert.Equal(t, big.NewInt(-1), blockNumber) }).Return(zeros.Bytes(), nil) - o := gas.NewArbitrumEstimator(logger.Test(t), &arbConfig{v: maxGasLimit, bumpPercent: bumpPercent, bumpMin: bumpMin}, rpcClient, ethClient) + o := gas.NewArbitrumEstimator(logger.Test(t), &arbConfig{v: maxGasLimit, bumpPercent: bumpPercent, bumpMin: bumpMin}, feeEstimatorClient, l1Oracle) servicetest.RunHealthy(t, o) gasPrice, chainSpecificGasLimit, err := o.GetLegacyGas(testutils.Context(t), calldata, gasLimit, maxGasPrice) require.NoError(t, err) @@ -88,19 +91,20 @@ func TestArbitrumEstimator(t *testing.T) { }) t.Run("gas price is lower than user specified max gas price", func(t *testing.T) { - client := mocks.NewRPCClient(t) - ethClient := mocks.NewETHClient(t) - o := gas.NewArbitrumEstimator(logger.Test(t), &arbConfig{}, client, ethClient) + feeEstimatorClient := mocks.NewFeeEstimatorClient(t) + l1Oracle := rollups.NewArbitrumL1GasOracle(logger.Test(t), feeEstimatorClient) + + o := gas.NewArbitrumEstimator(logger.Test(t), &arbConfig{}, feeEstimatorClient, l1Oracle) - client.On("CallContext", mock.Anything, mock.Anything, "eth_gasPrice").Return(nil).Run(func(args mock.Arguments) { + feeEstimatorClient.On("CallContext", mock.Anything, mock.Anything, "eth_gasPrice").Return(nil).Run(func(args mock.Arguments) { res := args.Get(1).(*hexutil.Big) (*big.Int)(res).SetInt64(42) }) - ethClient.On("CallContract", mock.Anything, mock.IsType(ethereum.CallMsg{}), mock.IsType(&big.Int{})).Run(func(args mock.Arguments) { + feeEstimatorClient.On("CallContract", mock.Anything, mock.IsType(ethereum.CallMsg{}), mock.IsType(&big.Int{})).Run(func(args mock.Arguments) { callMsg := args.Get(1).(ethereum.CallMsg) blockNumber := args.Get(2).(*big.Int) - assert.Equal(t, gas.ArbGasInfoAddress, callMsg.To.String()) - assert.Equal(t, gas.ArbGasInfo_getPricesInArbGas, fmt.Sprintf("%x", callMsg.Data)) + assert.Equal(t, rollups.ArbGasInfoAddress, callMsg.To.String()) + assert.Equal(t, rollups.ArbGasInfo_getPricesInArbGas, fmt.Sprintf("%x", callMsg.Data)) assert.Equal(t, big.NewInt(-1), blockNumber) }).Return(zeros.Bytes(), nil) @@ -113,19 +117,20 @@ func TestArbitrumEstimator(t *testing.T) { }) t.Run("gas price is lower than global max gas price", func(t *testing.T) { - ethClient := mocks.NewETHClient(t) - client := mocks.NewRPCClient(t) - o := gas.NewArbitrumEstimator(logger.Test(t), &arbConfig{}, client, ethClient) + feeEstimatorClient := mocks.NewFeeEstimatorClient(t) + l1Oracle := rollups.NewArbitrumL1GasOracle(logger.Test(t), feeEstimatorClient) - client.On("CallContext", mock.Anything, mock.Anything, "eth_gasPrice").Return(nil).Run(func(args mock.Arguments) { + o := gas.NewArbitrumEstimator(logger.Test(t), &arbConfig{}, feeEstimatorClient, l1Oracle) + + feeEstimatorClient.On("CallContext", mock.Anything, mock.Anything, "eth_gasPrice").Return(nil).Run(func(args mock.Arguments) { res := args.Get(1).(*hexutil.Big) (*big.Int)(res).SetInt64(120) }) - ethClient.On("CallContract", mock.Anything, mock.IsType(ethereum.CallMsg{}), mock.IsType(&big.Int{})).Run(func(args mock.Arguments) { + feeEstimatorClient.On("CallContract", mock.Anything, mock.IsType(ethereum.CallMsg{}), mock.IsType(&big.Int{})).Run(func(args mock.Arguments) { callMsg := args.Get(1).(ethereum.CallMsg) blockNumber := args.Get(2).(*big.Int) - assert.Equal(t, gas.ArbGasInfoAddress, callMsg.To.String()) - assert.Equal(t, gas.ArbGasInfo_getPricesInArbGas, fmt.Sprintf("%x", callMsg.Data)) + assert.Equal(t, rollups.ArbGasInfoAddress, callMsg.To.String()) + assert.Equal(t, rollups.ArbGasInfo_getPricesInArbGas, fmt.Sprintf("%x", callMsg.Data)) assert.Equal(t, big.NewInt(-1), blockNumber) }).Return(zeros.Bytes(), nil) @@ -137,24 +142,26 @@ func TestArbitrumEstimator(t *testing.T) { }) t.Run("calling BumpLegacyGas on unstarted arbitrum estimator returns error", func(t *testing.T) { - rpcClient := mocks.NewRPCClient(t) - ethClient := mocks.NewETHClient(t) - o := gas.NewArbitrumEstimator(logger.Test(t), &arbConfig{}, rpcClient, ethClient) + feeEstimatorClient := mocks.NewFeeEstimatorClient(t) + l1Oracle := rollups.NewArbitrumL1GasOracle(logger.Test(t), feeEstimatorClient) + + o := gas.NewArbitrumEstimator(logger.Test(t), &arbConfig{}, feeEstimatorClient, l1Oracle) _, _, err := o.BumpLegacyGas(testutils.Context(t), assets.NewWeiI(42), gasLimit, assets.NewWeiI(10), nil) assert.EqualError(t, err, "estimator is not started") }) t.Run("calling GetLegacyGas on started estimator if initial call failed returns error", func(t *testing.T) { - client := mocks.NewRPCClient(t) - ethClient := mocks.NewETHClient(t) - o := gas.NewArbitrumEstimator(logger.Test(t), &arbConfig{}, client, ethClient) + feeEstimatorClient := mocks.NewFeeEstimatorClient(t) + l1Oracle := rollups.NewArbitrumL1GasOracle(logger.Test(t), feeEstimatorClient) - client.On("CallContext", mock.Anything, mock.Anything, "eth_gasPrice").Return(pkgerrors.New("kaboom")) - ethClient.On("CallContract", mock.Anything, mock.IsType(ethereum.CallMsg{}), mock.IsType(&big.Int{})).Run(func(args mock.Arguments) { + o := gas.NewArbitrumEstimator(logger.Test(t), &arbConfig{}, feeEstimatorClient, l1Oracle) + + feeEstimatorClient.On("CallContext", mock.Anything, mock.Anything, "eth_gasPrice").Return(pkgerrors.New("kaboom")) + feeEstimatorClient.On("CallContract", mock.Anything, mock.IsType(ethereum.CallMsg{}), mock.IsType(&big.Int{})).Run(func(args mock.Arguments) { callMsg := args.Get(1).(ethereum.CallMsg) blockNumber := args.Get(2).(*big.Int) - assert.Equal(t, gas.ArbGasInfoAddress, callMsg.To.String()) - assert.Equal(t, gas.ArbGasInfo_getPricesInArbGas, fmt.Sprintf("%x", callMsg.Data)) + assert.Equal(t, rollups.ArbGasInfoAddress, callMsg.To.String()) + assert.Equal(t, rollups.ArbGasInfo_getPricesInArbGas, fmt.Sprintf("%x", callMsg.Data)) assert.Equal(t, big.NewInt(-1), blockNumber) }).Return(zeros.Bytes(), nil) @@ -165,17 +172,19 @@ func TestArbitrumEstimator(t *testing.T) { }) t.Run("calling GetDynamicFee always returns error", func(t *testing.T) { - rpcClient := mocks.NewRPCClient(t) - ethClient := mocks.NewETHClient(t) - o := gas.NewArbitrumEstimator(logger.Test(t), &arbConfig{}, rpcClient, ethClient) + feeEstimatorClient := mocks.NewFeeEstimatorClient(t) + l1Oracle := rollups.NewArbitrumL1GasOracle(logger.Test(t), feeEstimatorClient) + + o := gas.NewArbitrumEstimator(logger.Test(t), &arbConfig{}, feeEstimatorClient, l1Oracle) _, err := o.GetDynamicFee(testutils.Context(t), maxGasPrice) assert.EqualError(t, err, "dynamic fees are not implemented for this estimator") }) t.Run("calling BumpDynamicFee always returns error", func(t *testing.T) { - rpcClient := mocks.NewRPCClient(t) - ethClient := mocks.NewETHClient(t) - o := gas.NewArbitrumEstimator(logger.Test(t), &arbConfig{}, rpcClient, ethClient) + feeEstimatorClient := mocks.NewFeeEstimatorClient(t) + l1Oracle := rollups.NewArbitrumL1GasOracle(logger.Test(t), feeEstimatorClient) + + o := gas.NewArbitrumEstimator(logger.Test(t), &arbConfig{}, feeEstimatorClient, l1Oracle) fee := gas.DynamicFee{ FeeCap: assets.NewWeiI(42), TipCap: assets.NewWeiI(5), @@ -185,9 +194,10 @@ func TestArbitrumEstimator(t *testing.T) { }) t.Run("limit computes", func(t *testing.T) { - rpcClient := mocks.NewRPCClient(t) - ethClient := mocks.NewETHClient(t) - rpcClient.On("CallContext", mock.Anything, mock.Anything, "eth_gasPrice").Return(nil).Run(func(args mock.Arguments) { + feeEstimatorClient := mocks.NewFeeEstimatorClient(t) + l1Oracle := rollups.NewArbitrumL1GasOracle(logger.Test(t), feeEstimatorClient) + + feeEstimatorClient.On("CallContext", mock.Anything, mock.Anything, "eth_gasPrice").Return(nil).Run(func(args mock.Arguments) { res := args.Get(1).(*hexutil.Big) (*big.Int)(res).SetInt64(42) }) @@ -201,15 +211,15 @@ func TestArbitrumEstimator(t *testing.T) { b.Write(common.BigToHash(big.NewInt(perL2Tx)).Bytes()) b.Write(common.BigToHash(big.NewInt(perL1Calldata)).Bytes()) b.Write(common.BigToHash(big.NewInt(123455)).Bytes()) - ethClient.On("CallContract", mock.Anything, mock.IsType(ethereum.CallMsg{}), mock.IsType(&big.Int{})).Run(func(args mock.Arguments) { + feeEstimatorClient.On("CallContract", mock.Anything, mock.IsType(ethereum.CallMsg{}), mock.IsType(&big.Int{})).Run(func(args mock.Arguments) { callMsg := args.Get(1).(ethereum.CallMsg) blockNumber := args.Get(2).(*big.Int) - assert.Equal(t, gas.ArbGasInfoAddress, callMsg.To.String()) - assert.Equal(t, gas.ArbGasInfo_getPricesInArbGas, fmt.Sprintf("%x", callMsg.Data)) + assert.Equal(t, rollups.ArbGasInfoAddress, callMsg.To.String()) + assert.Equal(t, rollups.ArbGasInfo_getPricesInArbGas, fmt.Sprintf("%x", callMsg.Data)) assert.Equal(t, big.NewInt(-1), blockNumber) }).Return(b.Bytes(), nil) - o := gas.NewArbitrumEstimator(logger.Test(t), &arbConfig{v: maxGasLimit, bumpPercent: bumpPercent, bumpMin: bumpMin}, rpcClient, ethClient) + o := gas.NewArbitrumEstimator(logger.Test(t), &arbConfig{v: maxGasLimit, bumpPercent: bumpPercent, bumpMin: bumpMin}, feeEstimatorClient, l1Oracle) servicetest.RunHealthy(t, o) gasPrice, chainSpecificGasLimit, err := o.GetLegacyGas(testutils.Context(t), calldata, gasLimit, maxGasPrice) require.NoError(t, err) @@ -220,9 +230,10 @@ func TestArbitrumEstimator(t *testing.T) { }) t.Run("limit exceeds max", func(t *testing.T) { - rpcClient := mocks.NewRPCClient(t) - ethClient := mocks.NewETHClient(t) - rpcClient.On("CallContext", mock.Anything, mock.Anything, "eth_gasPrice").Return(nil).Run(func(args mock.Arguments) { + feeEstimatorClient := mocks.NewFeeEstimatorClient(t) + l1Oracle := rollups.NewArbitrumL1GasOracle(logger.Test(t), feeEstimatorClient) + + feeEstimatorClient.On("CallContext", mock.Anything, mock.Anything, "eth_gasPrice").Return(nil).Run(func(args mock.Arguments) { res := args.Get(1).(*hexutil.Big) (*big.Int)(res).SetInt64(42) }) @@ -235,15 +246,15 @@ func TestArbitrumEstimator(t *testing.T) { b.Write(common.BigToHash(big.NewInt(perL2Tx)).Bytes()) b.Write(common.BigToHash(big.NewInt(perL1Calldata)).Bytes()) b.Write(common.BigToHash(big.NewInt(123455)).Bytes()) - ethClient.On("CallContract", mock.Anything, mock.IsType(ethereum.CallMsg{}), mock.IsType(&big.Int{})).Run(func(args mock.Arguments) { + feeEstimatorClient.On("CallContract", mock.Anything, mock.IsType(ethereum.CallMsg{}), mock.IsType(&big.Int{})).Run(func(args mock.Arguments) { callMsg := args.Get(1).(ethereum.CallMsg) blockNumber := args.Get(2).(*big.Int) - assert.Equal(t, gas.ArbGasInfoAddress, callMsg.To.String()) - assert.Equal(t, gas.ArbGasInfo_getPricesInArbGas, fmt.Sprintf("%x", callMsg.Data)) + assert.Equal(t, rollups.ArbGasInfoAddress, callMsg.To.String()) + assert.Equal(t, rollups.ArbGasInfo_getPricesInArbGas, fmt.Sprintf("%x", callMsg.Data)) assert.Equal(t, big.NewInt(-1), blockNumber) }).Return(b.Bytes(), nil) - o := gas.NewArbitrumEstimator(logger.Test(t), &arbConfig{v: maxGasLimit, bumpPercent: bumpPercent, bumpMin: bumpMin}, rpcClient, ethClient) + o := gas.NewArbitrumEstimator(logger.Test(t), &arbConfig{v: maxGasLimit, bumpPercent: bumpPercent, bumpMin: bumpMin}, feeEstimatorClient, l1Oracle) servicetest.RunHealthy(t, o) gasPrice, chainSpecificGasLimit, err := o.GetLegacyGas(testutils.Context(t), calldata, gasLimit, maxGasPrice) require.Error(t, err, "expected error but got (%s, %d)", gasPrice, chainSpecificGasLimit) diff --git a/core/chains/evm/gas/block_history_estimator.go b/core/chains/evm/gas/block_history_estimator.go index 5fb9c5d7173..8b8c626f725 100644 --- a/core/chains/evm/gas/block_history_estimator.go +++ b/core/chains/evm/gas/block_history_estimator.go @@ -24,7 +24,7 @@ import ( commonfee "github.com/smartcontractkit/chainlink/v2/common/fee" feetypes "github.com/smartcontractkit/chainlink/v2/common/fee/types" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/assets" - evmclient "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" + "github.com/smartcontractkit/chainlink/v2/core/chains/evm/gas/rollups" evmtypes "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" ) @@ -97,7 +97,7 @@ type estimatorGasEstimatorConfig interface { //go:generate mockery --quiet --name Config --output ./mocks/ --case=underscore type BlockHistoryEstimator struct { services.StateMachine - ethClient evmclient.Client + ethClient feeEstimatorClient chainID big.Int config chainConfig eConfig estimatorGasEstimatorConfig @@ -120,13 +120,16 @@ type BlockHistoryEstimator struct { initialFetch atomic.Bool logger logger.SugaredLogger + + l1Oracle rollups.L1Oracle } // NewBlockHistoryEstimator returns a new BlockHistoryEstimator that listens // for new heads and updates the base gas price dynamically based on the // configured percentile of gas prices in that block -func NewBlockHistoryEstimator(lggr logger.Logger, ethClient evmclient.Client, cfg chainConfig, eCfg estimatorGasEstimatorConfig, bhCfg BlockHistoryConfig, chainID big.Int) EvmEstimator { +func NewBlockHistoryEstimator(lggr logger.Logger, ethClient feeEstimatorClient, cfg chainConfig, eCfg estimatorGasEstimatorConfig, bhCfg BlockHistoryConfig, chainID big.Int, l1Oracle rollups.L1Oracle) EvmEstimator { ctx, cancel := context.WithCancel(context.Background()) + b := &BlockHistoryEstimator{ ethClient: ethClient, chainID: chainID, @@ -141,6 +144,7 @@ func NewBlockHistoryEstimator(lggr logger.Logger, ethClient evmclient.Client, cf ctx: ctx, ctxCancel: cancel, logger: logger.Sugared(logger.Named(lggr, "BlockHistoryEstimator")), + l1Oracle: l1Oracle, } return b @@ -230,6 +234,10 @@ func (b *BlockHistoryEstimator) Start(ctx context.Context) error { }) } +func (b *BlockHistoryEstimator) L1Oracle() rollups.L1Oracle { + return b.l1Oracle +} + func (b *BlockHistoryEstimator) Close() error { return b.StopOnce("BlockHistoryEstimator", func() error { b.ctxCancel() diff --git a/core/chains/evm/gas/block_history_estimator_test.go b/core/chains/evm/gas/block_history_estimator_test.go index 5260a22bff3..941b60545ba 100644 --- a/core/chains/evm/gas/block_history_estimator_test.go +++ b/core/chains/evm/gas/block_history_estimator_test.go @@ -24,6 +24,8 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/chains/evm/assets" evmclient "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/gas" + "github.com/smartcontractkit/chainlink/v2/core/chains/evm/gas/rollups" + rollupMocks "github.com/smartcontractkit/chainlink/v2/core/chains/evm/gas/rollups/mocks" evmtypes "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils" ubig "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big" @@ -42,12 +44,12 @@ func newBlockHistoryConfig() *gas.MockBlockHistoryConfig { return c } -func newBlockHistoryEstimatorWithChainID(t *testing.T, c evmclient.Client, cfg gas.Config, gCfg gas.GasEstimatorConfig, bhCfg gas.BlockHistoryConfig, cid big.Int) gas.EvmEstimator { - return gas.NewBlockHistoryEstimator(logger.Test(t), c, cfg, gCfg, bhCfg, cid) +func newBlockHistoryEstimatorWithChainID(t *testing.T, c evmclient.Client, cfg gas.Config, gCfg gas.GasEstimatorConfig, bhCfg gas.BlockHistoryConfig, cid big.Int, l1Oracle rollups.L1Oracle) gas.EvmEstimator { + return gas.NewBlockHistoryEstimator(logger.Test(t), c, cfg, gCfg, bhCfg, cid, l1Oracle) } -func newBlockHistoryEstimator(t *testing.T, c evmclient.Client, cfg gas.Config, gCfg gas.GasEstimatorConfig, bhCfg gas.BlockHistoryConfig) *gas.BlockHistoryEstimator { - iface := newBlockHistoryEstimatorWithChainID(t, c, cfg, gCfg, bhCfg, cltest.FixtureChainID) +func newBlockHistoryEstimator(t *testing.T, c evmclient.Client, cfg gas.Config, gCfg gas.GasEstimatorConfig, bhCfg gas.BlockHistoryConfig, l1Oracle rollups.L1Oracle) *gas.BlockHistoryEstimator { + iface := newBlockHistoryEstimatorWithChainID(t, c, cfg, gCfg, bhCfg, cltest.FixtureChainID, l1Oracle) return gas.BlockHistoryEstimatorFromInterface(iface) } @@ -77,8 +79,9 @@ func TestBlockHistoryEstimator_Start(t *testing.T) { t.Run("loads initial state", func(t *testing.T) { ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + l1Oracle := rollupMocks.NewL1Oracle(t) - bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg) + bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg, l1Oracle) h := &evmtypes.Head{Hash: utils.NewHash(), Number: 42, BaseFeePerGas: assets.NewWeiI(420)} ethClient.On("HeadByNumber", mock.Anything, (*big.Int)(nil)).Return(h, nil) @@ -121,8 +124,9 @@ func TestBlockHistoryEstimator_Start(t *testing.T) { cfg2 := gas.NewMockConfig() ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + l1Oracle := rollupMocks.NewL1Oracle(t) - bhe := newBlockHistoryEstimator(t, ethClient, cfg2, geCfg2, bhCfg2) + bhe := newBlockHistoryEstimator(t, ethClient, cfg2, geCfg2, bhCfg2, l1Oracle) h := &evmtypes.Head{Hash: utils.NewHash(), Number: 42, BaseFeePerGas: assets.NewWeiI(420)} ethClient.On("HeadByNumber", mock.Anything, (*big.Int)(nil)).Return(h, nil) @@ -154,8 +158,9 @@ func TestBlockHistoryEstimator_Start(t *testing.T) { t.Run("boots even if initial batch call returns nothing", func(t *testing.T) { ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + l1Oracle := rollupMocks.NewL1Oracle(t) - bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg) + bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg, l1Oracle) h := &evmtypes.Head{Hash: utils.NewHash(), Number: 42} ethClient.On("HeadByNumber", mock.Anything, (*big.Int)(nil)).Return(h, nil) @@ -172,8 +177,9 @@ func TestBlockHistoryEstimator_Start(t *testing.T) { t.Run("starts anyway if fetching latest head fails", func(t *testing.T) { ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + l1Oracle := rollupMocks.NewL1Oracle(t) - bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg) + bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg, l1Oracle) ethClient.On("HeadByNumber", mock.Anything, (*big.Int)(nil)).Return(nil, pkgerrors.New("something exploded")) @@ -193,8 +199,9 @@ func TestBlockHistoryEstimator_Start(t *testing.T) { t.Run("starts anyway if fetching first fetch fails, but errors on estimation", func(t *testing.T) { ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + l1Oracle := rollupMocks.NewL1Oracle(t) - bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg) + bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg, l1Oracle) h := &evmtypes.Head{Hash: utils.NewHash(), Number: 42, BaseFeePerGas: assets.NewWeiI(420)} ethClient.On("HeadByNumber", mock.Anything, (*big.Int)(nil)).Return(h, nil) @@ -216,8 +223,9 @@ func TestBlockHistoryEstimator_Start(t *testing.T) { t.Run("returns error if main context is cancelled", func(t *testing.T) { ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + l1Oracle := rollupMocks.NewL1Oracle(t) - bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg) + bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg, l1Oracle) h := &evmtypes.Head{Hash: utils.NewHash(), Number: 42, BaseFeePerGas: assets.NewWeiI(420)} ethClient.On("HeadByNumber", mock.Anything, (*big.Int)(nil)).Return(h, nil) @@ -232,8 +240,9 @@ func TestBlockHistoryEstimator_Start(t *testing.T) { t.Run("starts anyway even if the fetch context is cancelled due to taking longer than the MaxStartTime", func(t *testing.T) { ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + l1Oracle := rollupMocks.NewL1Oracle(t) - bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg) + bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg, l1Oracle) h := &evmtypes.Head{Hash: utils.NewHash(), Number: 42, BaseFeePerGas: assets.NewWeiI(420)} ethClient.On("HeadByNumber", mock.Anything, (*big.Int)(nil)).Return(h, nil) @@ -261,8 +270,9 @@ func TestBlockHistoryEstimator_OnNewLongestChain(t *testing.T) { bhCfg := newBlockHistoryConfig() geCfg := &gas.MockGasEstimatorConfig{} geCfg.EIP1559DynamicFeesF = false + l1Oracle := rollupMocks.NewL1Oracle(t) - bhe := newBlockHistoryEstimator(t, nil, cfg, geCfg, bhCfg) + bhe := newBlockHistoryEstimator(t, nil, cfg, geCfg, bhCfg, l1Oracle) assert.Nil(t, gas.GetLatestBaseFee(bhe)) @@ -284,6 +294,8 @@ func TestBlockHistoryEstimator_FetchBlocks(t *testing.T) { t.Run("with history size of 0, errors", func(t *testing.T) { ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + l1Oracle := rollupMocks.NewL1Oracle(t) + cfg := gas.NewMockConfig() bhCfg := newBlockHistoryConfig() @@ -295,7 +307,7 @@ func TestBlockHistoryEstimator_FetchBlocks(t *testing.T) { geCfg := &gas.MockGasEstimatorConfig{} geCfg.EIP1559DynamicFeesF = true - bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg) + bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg, l1Oracle) head := cltest.Head(42) err := bhe.FetchBlocks(testutils.Context(t), head) @@ -305,6 +317,8 @@ func TestBlockHistoryEstimator_FetchBlocks(t *testing.T) { t.Run("with current block height less than block delay does nothing", func(t *testing.T) { ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + l1Oracle := rollupMocks.NewL1Oracle(t) + cfg := gas.NewMockConfig() bhCfg := newBlockHistoryConfig() var blockDelay uint16 = 3 @@ -315,7 +329,7 @@ func TestBlockHistoryEstimator_FetchBlocks(t *testing.T) { geCfg := &gas.MockGasEstimatorConfig{} geCfg.EIP1559DynamicFeesF = true - bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg) + bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg, l1Oracle) for i := -1; i < 3; i++ { head := cltest.Head(i) @@ -327,6 +341,8 @@ func TestBlockHistoryEstimator_FetchBlocks(t *testing.T) { t.Run("with error retrieving blocks returns error", func(t *testing.T) { ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + l1Oracle := rollupMocks.NewL1Oracle(t) + cfg := gas.NewMockConfig() bhCfg := newBlockHistoryConfig() var blockDelay uint16 = 3 @@ -338,7 +354,7 @@ func TestBlockHistoryEstimator_FetchBlocks(t *testing.T) { geCfg := &gas.MockGasEstimatorConfig{} geCfg.EIP1559DynamicFeesF = true - bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg) + bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg, l1Oracle) ethClient.On("BatchCallContext", mock.Anything, mock.Anything).Return(pkgerrors.New("something exploded")) @@ -349,6 +365,8 @@ func TestBlockHistoryEstimator_FetchBlocks(t *testing.T) { t.Run("batch fetches heads and transactions and sets them on the block history estimator instance", func(t *testing.T) { ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + l1Oracle := rollupMocks.NewL1Oracle(t) + cfg := gas.NewMockConfig() bhCfg := newBlockHistoryConfig() var blockDelay uint16 @@ -362,7 +380,7 @@ func TestBlockHistoryEstimator_FetchBlocks(t *testing.T) { geCfg := &gas.MockGasEstimatorConfig{} geCfg.EIP1559DynamicFeesF = true - bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg) + bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg, l1Oracle) b41 := evmtypes.Block{ Number: 41, @@ -443,6 +461,8 @@ func TestBlockHistoryEstimator_FetchBlocks(t *testing.T) { t.Run("does not refetch blocks below EVM.FinalityDepth", func(t *testing.T) { ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + l1Oracle := rollupMocks.NewL1Oracle(t) + cfg := gas.NewMockConfig() bhCfg := newBlockHistoryConfig() var blockDelay uint16 @@ -455,7 +475,7 @@ func TestBlockHistoryEstimator_FetchBlocks(t *testing.T) { geCfg := &gas.MockGasEstimatorConfig{} geCfg.EIP1559DynamicFeesF = true - bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg) + bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg, l1Oracle) b0 := evmtypes.Block{ Number: 0, @@ -506,6 +526,8 @@ func TestBlockHistoryEstimator_FetchBlocks(t *testing.T) { t.Run("replaces blocks on re-org within EVM.FinalityDepth", func(t *testing.T) { ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + l1Oracle := rollupMocks.NewL1Oracle(t) + cfg := gas.NewMockConfig() bhCfg := newBlockHistoryConfig() var blockDelay uint16 @@ -518,7 +540,7 @@ func TestBlockHistoryEstimator_FetchBlocks(t *testing.T) { geCfg := &gas.MockGasEstimatorConfig{} geCfg.EIP1559DynamicFeesF = true - bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg) + bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg, l1Oracle) b0 := evmtypes.Block{ Number: 0, @@ -577,6 +599,8 @@ func TestBlockHistoryEstimator_FetchBlocks(t *testing.T) { t.Run("uses locally cached blocks if they are in the chain", func(t *testing.T) { ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + l1Oracle := rollupMocks.NewL1Oracle(t) + cfg := gas.NewMockConfig() var blockDelay uint16 var historySize uint16 = 3 @@ -589,7 +613,7 @@ func TestBlockHistoryEstimator_FetchBlocks(t *testing.T) { geCfg := &gas.MockGasEstimatorConfig{} geCfg.EIP1559DynamicFeesF = true - bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg) + bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg, l1Oracle) b0 := evmtypes.Block{ Number: 0, @@ -634,6 +658,8 @@ func TestBlockHistoryEstimator_FetchBlocks(t *testing.T) { t.Run("fetches max(BlockHistoryEstimatorCheckInclusionBlocks, BlockHistoryEstimatorBlockHistorySize)", func(t *testing.T) { ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + l1Oracle := rollupMocks.NewL1Oracle(t) + cfg := gas.NewMockConfig() var blockDelay uint16 var historySize uint16 = 1 @@ -648,7 +674,7 @@ func TestBlockHistoryEstimator_FetchBlocks(t *testing.T) { geCfg := &gas.MockGasEstimatorConfig{} geCfg.EIP1559DynamicFeesF = true - bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg) + bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg, l1Oracle) b42 := evmtypes.Block{ Number: 42, @@ -686,6 +712,8 @@ func TestBlockHistoryEstimator_FetchBlocksAndRecalculate_NoEIP1559(t *testing.T) t.Parallel() ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + l1Oracle := rollupMocks.NewL1Oracle(t) + cfg := gas.NewMockConfig() bhCfg := newBlockHistoryConfig() bhCfg.BlockDelayF = uint16(0) @@ -698,7 +726,7 @@ func TestBlockHistoryEstimator_FetchBlocksAndRecalculate_NoEIP1559(t *testing.T) geCfg.PriceMaxF = assets.NewWeiI(1000) geCfg.PriceMinF = assets.NewWeiI(0) - bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg) + bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg, l1Oracle) b1 := evmtypes.Block{ Number: 1, @@ -744,6 +772,7 @@ func TestBlockHistoryEstimator_Recalculate_NoEIP1559(t *testing.T) { t.Run("does not crash or set gas price to zero if there are no transactions", func(t *testing.T) { ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + l1Oracle := rollupMocks.NewL1Oracle(t) cfg := gas.NewMockConfig() bhCfg := newBlockHistoryConfig() @@ -753,7 +782,7 @@ func TestBlockHistoryEstimator_Recalculate_NoEIP1559(t *testing.T) { geCfg := &gas.MockGasEstimatorConfig{} geCfg.EIP1559DynamicFeesF = false - bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg) + bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg, l1Oracle) blocks := []evmtypes.Block{} gas.SetRollingBlockHistory(bhe, blocks) @@ -770,6 +799,8 @@ func TestBlockHistoryEstimator_Recalculate_NoEIP1559(t *testing.T) { t.Run("sets gas price to EVM.GasEstimator.PriceMax if the calculation would otherwise exceed it", func(t *testing.T) { ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + l1Oracle := rollupMocks.NewL1Oracle(t) + cfg := gas.NewMockConfig() bhCfg := newBlockHistoryConfig() @@ -780,7 +811,7 @@ func TestBlockHistoryEstimator_Recalculate_NoEIP1559(t *testing.T) { geCfg.PriceMaxF = maxGasPrice geCfg.PriceMinF = minGasPrice - bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg) + bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg, l1Oracle) blocks := []evmtypes.Block{ { @@ -805,6 +836,8 @@ func TestBlockHistoryEstimator_Recalculate_NoEIP1559(t *testing.T) { t.Run("sets gas price to EVM.GasEstimator.PriceMin if the calculation would otherwise fall below it", func(t *testing.T) { ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + l1Oracle := rollupMocks.NewL1Oracle(t) + cfg := gas.NewMockConfig() bhCfg := newBlockHistoryConfig() @@ -815,7 +848,7 @@ func TestBlockHistoryEstimator_Recalculate_NoEIP1559(t *testing.T) { geCfg.PriceMaxF = maxGasPrice geCfg.PriceMinF = minGasPrice - bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg) + bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg, l1Oracle) blocks := []evmtypes.Block{ { @@ -840,6 +873,8 @@ func TestBlockHistoryEstimator_Recalculate_NoEIP1559(t *testing.T) { t.Run("ignores any transaction with a zero gas limit", func(t *testing.T) { ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + l1Oracle := rollupMocks.NewL1Oracle(t) + cfg := gas.NewMockConfig() bhCfg := newBlockHistoryConfig() @@ -850,7 +885,7 @@ func TestBlockHistoryEstimator_Recalculate_NoEIP1559(t *testing.T) { geCfg.PriceMaxF = maxGasPrice geCfg.PriceMinF = minGasPrice - bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg) + bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg, l1Oracle) b1Hash := utils.NewHash() b2Hash := utils.NewHash() @@ -887,6 +922,8 @@ func TestBlockHistoryEstimator_Recalculate_NoEIP1559(t *testing.T) { t.Run("takes into account zero priced transactions if chain is not Gnosis", func(t *testing.T) { // Because everyone loves free gas! ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + l1Oracle := rollupMocks.NewL1Oracle(t) + cfg := gas.NewMockConfig() bhCfg := newBlockHistoryConfig() @@ -897,7 +934,7 @@ func TestBlockHistoryEstimator_Recalculate_NoEIP1559(t *testing.T) { geCfg.PriceMaxF = maxGasPrice geCfg.PriceMinF = assets.NewWeiI(0) - bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg) + bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg, l1Oracle) b1Hash := utils.NewHash() @@ -920,6 +957,8 @@ func TestBlockHistoryEstimator_Recalculate_NoEIP1559(t *testing.T) { t.Run("ignores zero priced transactions only on Gnosis", func(t *testing.T) { ethClient := evmtest.NewEthClientMock(t) + l1Oracle := rollupMocks.NewL1Oracle(t) + cfg := gas.NewMockConfig() bhCfg := newBlockHistoryConfig() @@ -930,7 +969,7 @@ func TestBlockHistoryEstimator_Recalculate_NoEIP1559(t *testing.T) { geCfg.PriceMaxF = maxGasPrice geCfg.PriceMinF = assets.NewWeiI(11) // Has to be set as Gnosis will only ignore transactions below this price - ibhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg) + ibhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg, l1Oracle) bhe := gas.BlockHistoryEstimatorFromInterface(ibhe) b1Hash := utils.NewHash() @@ -964,6 +1003,8 @@ func TestBlockHistoryEstimator_Recalculate_NoEIP1559(t *testing.T) { // Seems unlikely we will ever experience gas prices > 9 Petawei on mainnet (praying to the eth Gods 🙏) // But other chains could easily use a different base of account ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + l1Oracle := rollupMocks.NewL1Oracle(t) + cfg := gas.NewMockConfig() bhCfg := newBlockHistoryConfig() @@ -976,7 +1017,7 @@ func TestBlockHistoryEstimator_Recalculate_NoEIP1559(t *testing.T) { geCfg.PriceMaxF = reasonablyHugeGasPrice geCfg.PriceMinF = assets.NewWeiI(10) - bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg) + bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg, l1Oracle) unreasonablyHugeGasPrice := assets.NewWeiI(1000000).Mul(big.NewInt(math.MaxInt64)) @@ -1011,6 +1052,8 @@ func TestBlockHistoryEstimator_Recalculate_NoEIP1559(t *testing.T) { t.Run("doesn't panic if gas price is nil (although I'm still unsure how this can happen)", func(t *testing.T) { ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + l1Oracle := rollupMocks.NewL1Oracle(t) + cfg := gas.NewMockConfig() bhCfg := newBlockHistoryConfig() @@ -1021,7 +1064,7 @@ func TestBlockHistoryEstimator_Recalculate_NoEIP1559(t *testing.T) { geCfg.PriceMaxF = maxGasPrice geCfg.PriceMinF = assets.NewWeiI(100) - bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg) + bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg, l1Oracle) b1Hash := utils.NewHash() @@ -1057,6 +1100,7 @@ func TestBlockHistoryEstimator_Recalculate_EIP1559(t *testing.T) { t.Run("does not crash or set gas price to zero if there are no transactions", func(t *testing.T) { ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + l1Oracle := rollupMocks.NewL1Oracle(t) cfg := gas.NewMockConfig() bhCfg := newBlockHistoryConfig() @@ -1066,7 +1110,7 @@ func TestBlockHistoryEstimator_Recalculate_EIP1559(t *testing.T) { geCfg := &gas.MockGasEstimatorConfig{} geCfg.EIP1559DynamicFeesF = true - bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg) + bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg, l1Oracle) blocks := []evmtypes.Block{} gas.SetRollingBlockHistory(bhe, blocks) @@ -1095,6 +1139,8 @@ func TestBlockHistoryEstimator_Recalculate_EIP1559(t *testing.T) { t.Run("does not set tip higher than EVM.GasEstimator.PriceMax", func(t *testing.T) { ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + l1Oracle := rollupMocks.NewL1Oracle(t) + cfg := gas.NewMockConfig() bhCfg := newBlockHistoryConfig() @@ -1106,7 +1152,7 @@ func TestBlockHistoryEstimator_Recalculate_EIP1559(t *testing.T) { geCfg.PriceMinF = assets.NewWeiI(0) geCfg.TipCapMinF = assets.NewWeiI(0) - bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg) + bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg, l1Oracle) blocks := []evmtypes.Block{ { @@ -1133,6 +1179,8 @@ func TestBlockHistoryEstimator_Recalculate_EIP1559(t *testing.T) { t.Run("sets tip cap to EVM.GasEstimator.TipCapMin if the calculation would otherwise fall below it", func(t *testing.T) { ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + l1Oracle := rollupMocks.NewL1Oracle(t) + cfg := gas.NewMockConfig() bhCfg := newBlockHistoryConfig() @@ -1144,7 +1192,7 @@ func TestBlockHistoryEstimator_Recalculate_EIP1559(t *testing.T) { geCfg.PriceMinF = assets.NewWeiI(0) geCfg.TipCapMinF = assets.NewWeiI(10) - bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg) + bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg, l1Oracle) blocks := []evmtypes.Block{ { @@ -1171,6 +1219,8 @@ func TestBlockHistoryEstimator_Recalculate_EIP1559(t *testing.T) { t.Run("ignores any transaction with a zero gas limit", func(t *testing.T) { ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + l1Oracle := rollupMocks.NewL1Oracle(t) + cfg := gas.NewMockConfig() bhCfg := newBlockHistoryConfig() @@ -1182,7 +1232,7 @@ func TestBlockHistoryEstimator_Recalculate_EIP1559(t *testing.T) { geCfg.PriceMinF = assets.NewWeiI(0) geCfg.TipCapMinF = assets.NewWeiI(10) - bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg) + bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg, l1Oracle) b1Hash := utils.NewHash() b2Hash := utils.NewHash() @@ -1219,6 +1269,8 @@ func TestBlockHistoryEstimator_Recalculate_EIP1559(t *testing.T) { t.Run("respects minimum gas tip cap", func(t *testing.T) { ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + l1Oracle := rollupMocks.NewL1Oracle(t) + cfg := gas.NewMockConfig() bhCfg := newBlockHistoryConfig() @@ -1230,7 +1282,7 @@ func TestBlockHistoryEstimator_Recalculate_EIP1559(t *testing.T) { geCfg.PriceMinF = assets.NewWeiI(0) geCfg.TipCapMinF = assets.NewWeiI(1) - bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg) + bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg, l1Oracle) b1Hash := utils.NewHash() @@ -1255,6 +1307,8 @@ func TestBlockHistoryEstimator_Recalculate_EIP1559(t *testing.T) { t.Run("allows to set zero tip cap if minimum allows it", func(t *testing.T) { // Because everyone loves *cheap* gas! ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + l1Oracle := rollupMocks.NewL1Oracle(t) + cfg := gas.NewMockConfig() bhCfg := newBlockHistoryConfig() @@ -1266,7 +1320,7 @@ func TestBlockHistoryEstimator_Recalculate_EIP1559(t *testing.T) { geCfg.PriceMinF = assets.NewWeiI(0) geCfg.TipCapMinF = assets.NewWeiI(0) - bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg) + bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg, l1Oracle) b1Hash := utils.NewHash() @@ -1291,12 +1345,14 @@ func TestBlockHistoryEstimator_Recalculate_EIP1559(t *testing.T) { func TestBlockHistoryEstimator_IsUsable(t *testing.T) { ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + l1Oracle := rollupMocks.NewL1Oracle(t) + cfg := gas.NewMockConfig() bhCfg := newBlockHistoryConfig() geCfg := &gas.MockGasEstimatorConfig{} geCfg.EIP1559DynamicFeesF = true - bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg) + bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg, l1Oracle) block := evmtypes.Block{ Number: 0, Hash: utils.NewHash(), @@ -1373,13 +1429,15 @@ func TestBlockHistoryEstimator_IsUsable(t *testing.T) { func TestBlockHistoryEstimator_EffectiveTipCap(t *testing.T) { ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + l1Oracle := rollupMocks.NewL1Oracle(t) + cfg := gas.NewMockConfig() bhCfg := newBlockHistoryConfig() geCfg := &gas.MockGasEstimatorConfig{} geCfg.EIP1559DynamicFeesF = true - bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg) + bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg, l1Oracle) block := evmtypes.Block{ Number: 0, @@ -1433,13 +1491,15 @@ func TestBlockHistoryEstimator_EffectiveTipCap(t *testing.T) { func TestBlockHistoryEstimator_EffectiveGasPrice(t *testing.T) { ethClient := evmtest.NewEthClientMockWithDefaultChain(t) + l1Oracle := rollupMocks.NewL1Oracle(t) + cfg := gas.NewMockConfig() bhCfg := newBlockHistoryConfig() geCfg := &gas.MockGasEstimatorConfig{} geCfg.EIP1559DynamicFeesF = false - bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg) + bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg, l1Oracle) block := evmtypes.Block{ Number: 0, @@ -1772,6 +1832,7 @@ func TestBlockHistoryEstimator_EIP1559Block_Unmarshal(t *testing.T) { func TestBlockHistoryEstimator_GetLegacyGas(t *testing.T) { t.Parallel() + l1Oracle := rollupMocks.NewL1Oracle(t) cfg := gas.NewMockConfig() bhCfg := newBlockHistoryConfig() @@ -1786,7 +1847,7 @@ func TestBlockHistoryEstimator_GetLegacyGas(t *testing.T) { geCfg.PriceMaxF = maxGasPrice geCfg.PriceMinF = assets.NewWeiI(0) - bhe := newBlockHistoryEstimator(t, nil, cfg, geCfg, bhCfg) + bhe := newBlockHistoryEstimator(t, nil, cfg, geCfg, bhCfg, l1Oracle) blocks := []evmtypes.Block{ { @@ -1830,7 +1891,7 @@ func TestBlockHistoryEstimator_GetLegacyGas(t *testing.T) { geCfg.EIP1559DynamicFeesF = false - bhe = newBlockHistoryEstimator(t, nil, cfg, geCfg, bhCfg) + bhe = newBlockHistoryEstimator(t, nil, cfg, geCfg, bhCfg, l1Oracle) gas.SetRollingBlockHistory(bhe, blocks) bhe.Recalculate(cltest.Head(1)) gas.SimulateStart(t, bhe) @@ -1867,7 +1928,9 @@ func TestBlockHistoryEstimator_UseDefaultPriceAsFallback(t *testing.T) { geCfg.PriceDefaultF = assets.NewWeiI(100) ethClient := evmtest.NewEthClientMockWithDefaultChain(t) - bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg) + l1Oracle := rollupMocks.NewL1Oracle(t) + + bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg, l1Oracle) h := &evmtypes.Head{Hash: utils.NewHash(), Number: 42, BaseFeePerGas: nil} ethClient.On("HeadByNumber", mock.Anything, (*big.Int)(nil)).Return(h, nil) @@ -1918,7 +1981,9 @@ func TestBlockHistoryEstimator_UseDefaultPriceAsFallback(t *testing.T) { geCfg.BumpThresholdF = uint64(1) ethClient := evmtest.NewEthClientMockWithDefaultChain(t) - bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg) + l1Oracle := rollupMocks.NewL1Oracle(t) + + bhe := newBlockHistoryEstimator(t, ethClient, cfg, geCfg, bhCfg, l1Oracle) h := &evmtypes.Head{Hash: utils.NewHash(), Number: 42, BaseFeePerGas: assets.NewWeiI(40)} ethClient.On("HeadByNumber", mock.Anything, (*big.Int)(nil)).Return(h, nil) @@ -1967,7 +2032,9 @@ func TestBlockHistoryEstimator_GetDynamicFee(t *testing.T) { geCfg.TipCapMinF = assets.NewWeiI(0) geCfg.PriceMinF = assets.NewWeiI(0) - bhe := newBlockHistoryEstimator(t, nil, cfg, geCfg, bhCfg) + l1Oracle := rollupMocks.NewL1Oracle(t) + + bhe := newBlockHistoryEstimator(t, nil, cfg, geCfg, bhCfg, l1Oracle) blocks := []evmtypes.Block{ { @@ -2065,9 +2132,10 @@ func TestBlockHistoryEstimator_CheckConnectivity(t *testing.T) { lggr, obs := logger.TestObserved(t, zapcore.DebugLevel) geCfg := &gas.MockGasEstimatorConfig{} geCfg.EIP1559DynamicFeesF = false + l1Oracle := rollupMocks.NewL1Oracle(t) bhe := gas.BlockHistoryEstimatorFromInterface( - gas.NewBlockHistoryEstimator(lggr, nil, cfg, geCfg, bhCfg, *testutils.NewRandomEVMChainID()), + gas.NewBlockHistoryEstimator(lggr, nil, cfg, geCfg, bhCfg, *testutils.NewRandomEVMChainID(), l1Oracle), ) attempts := []gas.EvmPriorAttempt{ @@ -2365,8 +2433,9 @@ func TestBlockHistoryEstimator_Bumps(t *testing.T) { geCfg.BumpPercentF = 10 geCfg.BumpMinF = assets.NewWeiI(150) geCfg.PriceMaxF = maxGasPrice + l1Oracle := rollupMocks.NewL1Oracle(t) - bhe := newBlockHistoryEstimator(t, nil, cfg, geCfg, bhCfg) + bhe := newBlockHistoryEstimator(t, nil, cfg, geCfg, bhCfg, l1Oracle) b1 := evmtypes.Block{ Number: 1, @@ -2394,8 +2463,9 @@ func TestBlockHistoryEstimator_Bumps(t *testing.T) { geCfg.BumpPercentF = 10 geCfg.BumpMinF = assets.NewWeiI(150) geCfg.PriceMaxF = maxGasPrice + l1Oracle := rollupMocks.NewL1Oracle(t) - bhe := newBlockHistoryEstimator(t, nil, cfg, geCfg, bhCfg) + bhe := newBlockHistoryEstimator(t, nil, cfg, geCfg, bhCfg, l1Oracle) t.Run("ignores nil current gas price", func(t *testing.T) { gasPrice, gasLimit, err := bhe.BumpLegacyGas(testutils.Context(t), assets.NewWeiI(42), 100000, maxGasPrice, nil) @@ -2475,8 +2545,9 @@ func TestBlockHistoryEstimator_Bumps(t *testing.T) { geCfg.BumpPercentF = 10 geCfg.BumpMinF = assets.NewWeiI(150) geCfg.PriceMaxF = maxGasPrice + l1Oracle := rollupMocks.NewL1Oracle(t) - bhe := newBlockHistoryEstimator(t, nil, cfg, geCfg, bhCfg) + bhe := newBlockHistoryEstimator(t, nil, cfg, geCfg, bhCfg, l1Oracle) b1 := evmtypes.Block{ BaseFeePerGas: assets.NewWeiI(1), @@ -2508,8 +2579,9 @@ func TestBlockHistoryEstimator_Bumps(t *testing.T) { geCfg.BumpMinF = assets.NewWeiI(150) geCfg.PriceMaxF = maxGasPrice geCfg.TipCapDefaultF = assets.NewWeiI(52) + l1Oracle := rollupMocks.NewL1Oracle(t) - bhe := newBlockHistoryEstimator(t, nil, cfg, geCfg, bhCfg) + bhe := newBlockHistoryEstimator(t, nil, cfg, geCfg, bhCfg, l1Oracle) t.Run("when current tip cap is nil", func(t *testing.T) { originalFee := gas.DynamicFee{FeeCap: assets.NewWeiI(100), TipCap: assets.NewWeiI(25)} diff --git a/core/chains/evm/gas/cmd/arbgas/main.go b/core/chains/evm/gas/cmd/arbgas/main.go deleted file mode 100644 index dc107a50b52..00000000000 --- a/core/chains/evm/gas/cmd/arbgas/main.go +++ /dev/null @@ -1,85 +0,0 @@ -// arbgas takes a single URL argument and prints the result of three GetLegacyGas calls to the Arbitrum gas estimator. -package main - -import ( - "context" - "fmt" - "log" - "os" - - "github.com/ethereum/go-ethereum/ethclient" - "github.com/ethereum/go-ethereum/rpc" - - "github.com/smartcontractkit/chainlink-common/pkg/logger" - feetypes "github.com/smartcontractkit/chainlink/v2/common/fee/types" - "github.com/smartcontractkit/chainlink/v2/core/chains/evm/assets" - "github.com/smartcontractkit/chainlink/v2/core/chains/evm/gas" -) - -func main() { - if l := len(os.Args); l != 2 { - log.Fatal("Expected one URL argument but got", l-1) - } - url := os.Args[1] - lggr, err := logger.New() - if err != nil { - log.Fatal("Failed to create logger:", err) - } - - ctx := context.Background() - withEstimator(ctx, logger.Sugared(lggr), url, func(e gas.EvmEstimator) { - printGetLegacyGas(ctx, e, make([]byte, 10), 500_000, assets.GWei(1)) - printGetLegacyGas(ctx, e, make([]byte, 10), 500_000, assets.GWei(1), feetypes.OptForceRefetch) - printGetLegacyGas(ctx, e, make([]byte, 10), max, assets.GWei(1)) - }) -} - -func printGetLegacyGas(ctx context.Context, e gas.EvmEstimator, calldata []byte, l2GasLimit uint64, maxGasPrice *assets.Wei, opts ...feetypes.Opt) { - price, limit, err := e.GetLegacyGas(ctx, calldata, l2GasLimit, maxGasPrice, opts...) - if err != nil { - log.Println("failed to get legacy gas:", err) - return - } - fmt.Println("Price:", price) - fmt.Println("Limit:", limit) -} - -const max = 50_000_000 - -func withEstimator(ctx context.Context, lggr logger.SugaredLogger, url string, f func(e gas.EvmEstimator)) { - rc, err := rpc.Dial(url) - if err != nil { - log.Fatal(err) - } - ec := ethclient.NewClient(rc) - e := gas.NewArbitrumEstimator(lggr, &config{max: max}, rc, ec) - ctx, cancel := context.WithCancel(ctx) - defer cancel() - err = e.Start(ctx) - if err != nil { - log.Fatal(err) - } - defer lggr.ErrorIfFn(e.Close, "Error closing ArbitrumEstimator") - - f(e) -} - -var _ gas.ArbConfig = &config{} - -type config struct { - max uint64 - bumpPercent uint16 - bumpMin *assets.Wei -} - -func (c *config) LimitMax() uint64 { - return c.max -} - -func (c *config) BumpPercent() uint16 { - return c.bumpPercent -} - -func (c *config) BumpMin() *assets.Wei { - return c.bumpMin -} diff --git a/core/chains/evm/gas/fixed_price_estimator.go b/core/chains/evm/gas/fixed_price_estimator.go index fc65413d375..f4749b093a1 100644 --- a/core/chains/evm/gas/fixed_price_estimator.go +++ b/core/chains/evm/gas/fixed_price_estimator.go @@ -9,6 +9,7 @@ import ( commonfee "github.com/smartcontractkit/chainlink/v2/common/fee" feetypes "github.com/smartcontractkit/chainlink/v2/common/fee/types" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/assets" + "github.com/smartcontractkit/chainlink/v2/core/chains/evm/gas/rollups" evmtypes "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" ) @@ -18,6 +19,7 @@ type fixedPriceEstimator struct { config fixedPriceEstimatorConfig bhConfig fixedPriceEstimatorBlockHistoryConfig lggr logger.SugaredLogger + l1Oracle rollups.L1Oracle } type bumpConfig interface { LimitMultiplier() float32 @@ -43,8 +45,8 @@ type fixedPriceEstimatorBlockHistoryConfig interface { // NewFixedPriceEstimator returns a new "FixedPrice" estimator which will // always use the config default values for gas prices and limits -func NewFixedPriceEstimator(cfg fixedPriceEstimatorConfig, bhCfg fixedPriceEstimatorBlockHistoryConfig, lggr logger.Logger) EvmEstimator { - return &fixedPriceEstimator{cfg, bhCfg, logger.Sugared(logger.Named(lggr, "FixedPriceEstimator"))} +func NewFixedPriceEstimator(cfg fixedPriceEstimatorConfig, ethClient feeEstimatorClient, bhCfg fixedPriceEstimatorBlockHistoryConfig, lggr logger.Logger, l1Oracle rollups.L1Oracle) EvmEstimator { + return &fixedPriceEstimator{cfg, bhCfg, logger.Sugared(logger.Named(lggr, "FixedPriceEstimator")), l1Oracle} } func (f *fixedPriceEstimator) Start(context.Context) error { @@ -128,6 +130,10 @@ func (f *fixedPriceEstimator) BumpDynamicFee( ) } +func (f *fixedPriceEstimator) L1Oracle() rollups.L1Oracle { + return f.l1Oracle +} + func (f *fixedPriceEstimator) Name() string { return f.lggr.Name() } func (f *fixedPriceEstimator) Ready() error { return nil } func (f *fixedPriceEstimator) HealthReport() map[string]error { return map[string]error{} } diff --git a/core/chains/evm/gas/fixed_price_estimator_test.go b/core/chains/evm/gas/fixed_price_estimator_test.go index c31bd41aeee..9c68f9d2fbc 100644 --- a/core/chains/evm/gas/fixed_price_estimator_test.go +++ b/core/chains/evm/gas/fixed_price_estimator_test.go @@ -9,6 +9,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/assets" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/gas" + rollupMocks "github.com/smartcontractkit/chainlink/v2/core/chains/evm/gas/rollups/mocks" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" ) @@ -26,7 +27,9 @@ func Test_FixedPriceEstimator(t *testing.T) { t.Run("GetLegacyGas returns EvmGasPriceDefault from config", func(t *testing.T) { config := &gas.MockGasEstimatorConfig{} - f := gas.NewFixedPriceEstimator(config, &blockHistoryConfig{}, logger.Test(t)) + l1Oracle := rollupMocks.NewL1Oracle(t) + + f := gas.NewFixedPriceEstimator(config, nil, &blockHistoryConfig{}, logger.Test(t), l1Oracle) config.PriceDefaultF = assets.NewWeiI(42) config.PriceMaxF = maxGasPrice @@ -41,7 +44,9 @@ func Test_FixedPriceEstimator(t *testing.T) { config := &gas.MockGasEstimatorConfig{} config.PriceDefaultF = assets.NewWeiI(42) config.PriceMaxF = assets.NewWeiI(35) - f := gas.NewFixedPriceEstimator(config, &blockHistoryConfig{}, logger.Test(t)) + l1Oracle := rollupMocks.NewL1Oracle(t) + + f := gas.NewFixedPriceEstimator(config, nil, &blockHistoryConfig{}, logger.Test(t), l1Oracle) gasPrice, gasLimit, err := f.GetLegacyGas(testutils.Context(t), nil, 100000, assets.NewWeiI(30)) require.NoError(t, err) @@ -53,8 +58,9 @@ func Test_FixedPriceEstimator(t *testing.T) { config := &gas.MockGasEstimatorConfig{} config.PriceDefaultF = assets.NewWeiI(42) config.PriceMaxF = assets.NewWeiI(20) + l1Oracle := rollupMocks.NewL1Oracle(t) - f := gas.NewFixedPriceEstimator(config, &blockHistoryConfig{}, logger.Test(t)) + f := gas.NewFixedPriceEstimator(config, nil, &blockHistoryConfig{}, logger.Test(t), l1Oracle) gasPrice, gasLimit, err := f.GetLegacyGas(testutils.Context(t), nil, 100000, assets.NewWeiI(30)) require.NoError(t, err) assert.Equal(t, 100000, int(gasLimit)) @@ -67,9 +73,10 @@ func Test_FixedPriceEstimator(t *testing.T) { config.PriceMaxF = maxGasPrice config.BumpPercentF = uint16(10) config.BumpMinF = assets.NewWeiI(150) + l1Oracle := rollupMocks.NewL1Oracle(t) lggr := logger.TestSugared(t) - f := gas.NewFixedPriceEstimator(config, &blockHistoryConfig{}, lggr) + f := gas.NewFixedPriceEstimator(config, nil, &blockHistoryConfig{}, lggr, l1Oracle) gasPrice, gasLimit, err := f.BumpLegacyGas(testutils.Context(t), assets.NewWeiI(42), 100000, maxGasPrice, nil) require.NoError(t, err) @@ -87,9 +94,10 @@ func Test_FixedPriceEstimator(t *testing.T) { config.TipCapDefaultF = assets.NewWeiI(52) config.FeeCapDefaultF = assets.NewWeiI(100) config.BumpThresholdF = uint64(3) + l1Oracle := rollupMocks.NewL1Oracle(t) lggr := logger.Test(t) - f := gas.NewFixedPriceEstimator(config, &blockHistoryConfig{}, lggr) + f := gas.NewFixedPriceEstimator(config, nil, &blockHistoryConfig{}, lggr, l1Oracle) fee, err := f.GetDynamicFee(testutils.Context(t), maxGasPrice) require.NoError(t, err) @@ -120,9 +128,10 @@ func Test_FixedPriceEstimator(t *testing.T) { config.TipCapDefaultF = assets.NewWeiI(52) config.BumpMinF = assets.NewWeiI(150) config.BumpPercentF = uint16(10) + l1Oracle := rollupMocks.NewL1Oracle(t) lggr := logger.TestSugared(t) - f := gas.NewFixedPriceEstimator(config, &blockHistoryConfig{}, lggr) + f := gas.NewFixedPriceEstimator(config, nil, &blockHistoryConfig{}, lggr, l1Oracle) originalFee := gas.DynamicFee{FeeCap: assets.NewWeiI(100), TipCap: assets.NewWeiI(25)} fee, err := f.BumpDynamicFee(testutils.Context(t), originalFee, maxGasPrice, nil) diff --git a/core/chains/evm/gas/mocks/eth_client.go b/core/chains/evm/gas/mocks/eth_client.go deleted file mode 100644 index bb0784f8515..00000000000 --- a/core/chains/evm/gas/mocks/eth_client.go +++ /dev/null @@ -1,61 +0,0 @@ -// Code generated by mockery v2.38.0. DO NOT EDIT. - -package mocks - -import ( - context "context" - big "math/big" - - ethereum "github.com/ethereum/go-ethereum" - - mock "github.com/stretchr/testify/mock" -) - -// ETHClient is an autogenerated mock type for the ethClient type -type ETHClient struct { - mock.Mock -} - -// CallContract provides a mock function with given fields: ctx, msg, blockNumber -func (_m *ETHClient) CallContract(ctx context.Context, msg ethereum.CallMsg, blockNumber *big.Int) ([]byte, error) { - ret := _m.Called(ctx, msg, blockNumber) - - if len(ret) == 0 { - panic("no return value specified for CallContract") - } - - var r0 []byte - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, ethereum.CallMsg, *big.Int) ([]byte, error)); ok { - return rf(ctx, msg, blockNumber) - } - if rf, ok := ret.Get(0).(func(context.Context, ethereum.CallMsg, *big.Int) []byte); ok { - r0 = rf(ctx, msg, blockNumber) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]byte) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, ethereum.CallMsg, *big.Int) error); ok { - r1 = rf(ctx, msg, blockNumber) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// NewETHClient creates a new instance of ETHClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewETHClient(t interface { - mock.TestingT - Cleanup(func()) -}) *ETHClient { - mock := ÐClient{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} diff --git a/core/chains/evm/gas/mocks/evm_estimator.go b/core/chains/evm/gas/mocks/evm_estimator.go index a0b6fa62432..600e43a7c69 100644 --- a/core/chains/evm/gas/mocks/evm_estimator.go +++ b/core/chains/evm/gas/mocks/evm_estimator.go @@ -13,6 +13,8 @@ import ( mock "github.com/stretchr/testify/mock" + rollups "github.com/smartcontractkit/chainlink/v2/core/chains/evm/gas/rollups" + types "github.com/smartcontractkit/chainlink/v2/common/fee/types" ) @@ -196,6 +198,26 @@ func (_m *EvmEstimator) HealthReport() map[string]error { return r0 } +// L1Oracle provides a mock function with given fields: +func (_m *EvmEstimator) L1Oracle() rollups.L1Oracle { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for L1Oracle") + } + + var r0 rollups.L1Oracle + if rf, ok := ret.Get(0).(func() rollups.L1Oracle); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(rollups.L1Oracle) + } + } + + return r0 +} + // Name provides a mock function with given fields: func (_m *EvmEstimator) Name() string { ret := _m.Called() diff --git a/core/chains/evm/gas/mocks/fee_estimator_client.go b/core/chains/evm/gas/mocks/fee_estimator_client.go new file mode 100644 index 00000000000..50eb17d2dac --- /dev/null +++ b/core/chains/evm/gas/mocks/fee_estimator_client.go @@ -0,0 +1,154 @@ +// Code generated by mockery v2.38.0. DO NOT EDIT. + +package mocks + +import ( + context "context" + big "math/big" + + ethereum "github.com/ethereum/go-ethereum" + + mock "github.com/stretchr/testify/mock" + + rpc "github.com/ethereum/go-ethereum/rpc" + + types "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" +) + +// FeeEstimatorClient is an autogenerated mock type for the feeEstimatorClient type +type FeeEstimatorClient struct { + mock.Mock +} + +// BatchCallContext provides a mock function with given fields: ctx, b +func (_m *FeeEstimatorClient) BatchCallContext(ctx context.Context, b []rpc.BatchElem) error { + ret := _m.Called(ctx, b) + + if len(ret) == 0 { + panic("no return value specified for BatchCallContext") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, []rpc.BatchElem) error); ok { + r0 = rf(ctx, b) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// CallContext provides a mock function with given fields: ctx, result, method, args +func (_m *FeeEstimatorClient) CallContext(ctx context.Context, result interface{}, method string, args ...interface{}) error { + var _ca []interface{} + _ca = append(_ca, ctx, result, method) + _ca = append(_ca, args...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for CallContext") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, interface{}, string, ...interface{}) error); ok { + r0 = rf(ctx, result, method, args...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// CallContract provides a mock function with given fields: ctx, msg, blockNumber +func (_m *FeeEstimatorClient) CallContract(ctx context.Context, msg ethereum.CallMsg, blockNumber *big.Int) ([]byte, error) { + ret := _m.Called(ctx, msg, blockNumber) + + if len(ret) == 0 { + panic("no return value specified for CallContract") + } + + var r0 []byte + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, ethereum.CallMsg, *big.Int) ([]byte, error)); ok { + return rf(ctx, msg, blockNumber) + } + if rf, ok := ret.Get(0).(func(context.Context, ethereum.CallMsg, *big.Int) []byte); ok { + r0 = rf(ctx, msg, blockNumber) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, ethereum.CallMsg, *big.Int) error); ok { + r1 = rf(ctx, msg, blockNumber) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ConfiguredChainID provides a mock function with given fields: +func (_m *FeeEstimatorClient) ConfiguredChainID() *big.Int { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ConfiguredChainID") + } + + var r0 *big.Int + if rf, ok := ret.Get(0).(func() *big.Int); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*big.Int) + } + } + + return r0 +} + +// HeadByNumber provides a mock function with given fields: ctx, n +func (_m *FeeEstimatorClient) HeadByNumber(ctx context.Context, n *big.Int) (*types.Head, error) { + ret := _m.Called(ctx, n) + + if len(ret) == 0 { + panic("no return value specified for HeadByNumber") + } + + var r0 *types.Head + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *big.Int) (*types.Head, error)); ok { + return rf(ctx, n) + } + if rf, ok := ret.Get(0).(func(context.Context, *big.Int) *types.Head); ok { + r0 = rf(ctx, n) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*types.Head) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *big.Int) error); ok { + r1 = rf(ctx, n) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewFeeEstimatorClient creates a new instance of FeeEstimatorClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewFeeEstimatorClient(t interface { + mock.TestingT + Cleanup(func()) +}) *FeeEstimatorClient { + mock := &FeeEstimatorClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/core/chains/evm/gas/mocks/rpc_client.go b/core/chains/evm/gas/mocks/rpc_client.go deleted file mode 100644 index d1262665f66..00000000000 --- a/core/chains/evm/gas/mocks/rpc_client.go +++ /dev/null @@ -1,49 +0,0 @@ -// Code generated by mockery v2.38.0. DO NOT EDIT. - -package mocks - -import ( - context "context" - - mock "github.com/stretchr/testify/mock" -) - -// RPCClient is an autogenerated mock type for the rpcClient type -type RPCClient struct { - mock.Mock -} - -// CallContext provides a mock function with given fields: ctx, result, method, args -func (_m *RPCClient) CallContext(ctx context.Context, result interface{}, method string, args ...interface{}) error { - var _ca []interface{} - _ca = append(_ca, ctx, result, method) - _ca = append(_ca, args...) - ret := _m.Called(_ca...) - - if len(ret) == 0 { - panic("no return value specified for CallContext") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, interface{}, string, ...interface{}) error); ok { - r0 = rf(ctx, result, method, args...) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// NewRPCClient creates a new instance of RPCClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewRPCClient(t interface { - mock.TestingT - Cleanup(func()) -}) *RPCClient { - mock := &RPCClient{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} diff --git a/core/chains/evm/gas/models.go b/core/chains/evm/gas/models.go index 04673d5a622..c50e19373f1 100644 --- a/core/chains/evm/gas/models.go +++ b/core/chains/evm/gas/models.go @@ -5,8 +5,10 @@ import ( "fmt" "math/big" + "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethereum/go-ethereum/rpc" pkgerrors "github.com/pkg/errors" "github.com/smartcontractkit/chainlink-common/pkg/logger" @@ -18,7 +20,6 @@ import ( feetypes "github.com/smartcontractkit/chainlink/v2/common/fee/types" "github.com/smartcontractkit/chainlink/v2/common/headtracker" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/assets" - evmclient "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" evmconfig "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/gas/rollups" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/label" @@ -41,8 +42,17 @@ type EvmFeeEstimator interface { GetMaxCost(ctx context.Context, amount assets.Eth, calldata []byte, feeLimit uint64, maxFeePrice *assets.Wei, opts ...feetypes.Opt) (*big.Int, error) } +//go:generate mockery --quiet --name feeEstimatorClient --output ./mocks/ --case=underscore --structname FeeEstimatorClient +type feeEstimatorClient interface { + CallContract(ctx context.Context, msg ethereum.CallMsg, blockNumber *big.Int) ([]byte, error) + BatchCallContext(ctx context.Context, b []rpc.BatchElem) error + CallContext(ctx context.Context, result interface{}, method string, args ...interface{}) error + ConfiguredChainID() *big.Int + HeadByNumber(ctx context.Context, n *big.Int) (*evmtypes.Head, error) +} + // NewEstimator returns the estimator for a given config -func NewEstimator(lggr logger.Logger, ethClient evmclient.Client, cfg Config, geCfg evmconfig.GasEstimator) EvmFeeEstimator { +func NewEstimator(lggr logger.Logger, ethClient feeEstimatorClient, cfg Config, geCfg evmconfig.GasEstimator) EvmFeeEstimator { bh := geCfg.BlockHistory() s := geCfg.Mode() lggr.Infow(fmt.Sprintf("Initializing EVM gas estimator in mode: %s", s), @@ -75,27 +85,27 @@ func NewEstimator(lggr logger.Logger, ethClient evmclient.Client, cfg Config, ge switch s { case "Arbitrum": newEstimator = func(l logger.Logger) EvmEstimator { - return NewArbitrumEstimator(lggr, geCfg, ethClient, ethClient) + return NewArbitrumEstimator(lggr, geCfg, ethClient, rollups.NewArbitrumL1GasOracle(lggr, ethClient)) } case "BlockHistory": newEstimator = func(l logger.Logger) EvmEstimator { - return NewBlockHistoryEstimator(lggr, ethClient, cfg, geCfg, bh, *ethClient.ConfiguredChainID()) + return NewBlockHistoryEstimator(lggr, ethClient, cfg, geCfg, bh, *ethClient.ConfiguredChainID(), l1Oracle) } case "FixedPrice": newEstimator = func(l logger.Logger) EvmEstimator { - return NewFixedPriceEstimator(geCfg, bh, lggr) + return NewFixedPriceEstimator(geCfg, ethClient, bh, lggr, l1Oracle) } case "L2Suggested", "SuggestedPrice": newEstimator = func(l logger.Logger) EvmEstimator { - return NewSuggestedPriceEstimator(lggr, ethClient, geCfg) + return NewSuggestedPriceEstimator(lggr, ethClient, geCfg, l1Oracle) } default: lggr.Warnf("GasEstimator: unrecognised mode '%s', falling back to FixedPriceEstimator", s) newEstimator = func(l logger.Logger) EvmEstimator { - return NewFixedPriceEstimator(geCfg, bh, lggr) + return NewFixedPriceEstimator(geCfg, ethClient, bh, lggr, l1Oracle) } } - return NewWrappedEvmEstimator(lggr, newEstimator, df, l1Oracle, geCfg) + return NewEvmFeeEstimator(lggr, newEstimator, df, geCfg) } // DynamicFee encompasses both FeeCap and TipCap for EIP1559 transactions @@ -138,6 +148,8 @@ type EvmEstimator interface { // - be sorted in order from highest price to lowest price // - all be of transaction type 0x2 BumpDynamicFee(ctx context.Context, original DynamicFee, maxGasPriceWei *assets.Wei, attempts []EvmPriorAttempt) (bumped DynamicFee, err error) + + L1Oracle() rollups.L1Oracle } var _ feetypes.Fee = (*EvmFee)(nil) @@ -159,53 +171,53 @@ func (fee EvmFee) ValidDynamic() bool { return fee.DynamicFeeCap != nil && fee.DynamicTipCap != nil } -// WrappedEvmEstimator provides a struct that wraps the EVM specific dynamic and legacy estimators into one estimator that conforms to the generic FeeEstimator -type WrappedEvmEstimator struct { +// evmFeeEstimator provides a struct that wraps the EVM specific dynamic and legacy estimators into one estimator that conforms to the generic FeeEstimator +type evmFeeEstimator struct { services.StateMachine lggr logger.Logger EvmEstimator EIP1559Enabled bool - l1Oracle rollups.L1Oracle geCfg GasEstimatorConfig } -var _ EvmFeeEstimator = (*WrappedEvmEstimator)(nil) +var _ EvmFeeEstimator = (*evmFeeEstimator)(nil) -func NewWrappedEvmEstimator(lggr logger.Logger, newEstimator func(logger.Logger) EvmEstimator, eip1559Enabled bool, l1Oracle rollups.L1Oracle, geCfg GasEstimatorConfig) EvmFeeEstimator { +func NewEvmFeeEstimator(lggr logger.Logger, newEstimator func(logger.Logger) EvmEstimator, eip1559Enabled bool, geCfg GasEstimatorConfig) EvmFeeEstimator { lggr = logger.Named(lggr, "WrappedEvmEstimator") - return &WrappedEvmEstimator{ + return &evmFeeEstimator{ lggr: lggr, EvmEstimator: newEstimator(lggr), EIP1559Enabled: eip1559Enabled, - l1Oracle: l1Oracle, geCfg: geCfg, } } -func (e *WrappedEvmEstimator) Name() string { +func (e *evmFeeEstimator) Name() string { return e.lggr.Name() } -func (e *WrappedEvmEstimator) Start(ctx context.Context) error { +func (e *evmFeeEstimator) Start(ctx context.Context) error { return e.StartOnce(e.Name(), func() error { if err := e.EvmEstimator.Start(ctx); err != nil { return pkgerrors.Wrap(err, "failed to start EVMEstimator") } - if e.l1Oracle != nil { - if err := e.l1Oracle.Start(ctx); err != nil { + l1Oracle := e.L1Oracle() + if l1Oracle != nil { + if err := l1Oracle.Start(ctx); err != nil { return pkgerrors.Wrap(err, "failed to start L1Oracle") } } return nil }) } -func (e *WrappedEvmEstimator) Close() error { +func (e *evmFeeEstimator) Close() error { return e.StopOnce(e.Name(), func() error { var errEVM, errOracle error errEVM = pkgerrors.Wrap(e.EvmEstimator.Close(), "failed to stop EVMEstimator") - if e.l1Oracle != nil { - errOracle = pkgerrors.Wrap(e.l1Oracle.Close(), "failed to stop L1Oracle") + l1Oracle := e.L1Oracle() + if l1Oracle != nil { + errOracle = pkgerrors.Wrap(l1Oracle.Close(), "failed to stop L1Oracle") } if errEVM != nil { @@ -215,12 +227,13 @@ func (e *WrappedEvmEstimator) Close() error { }) } -func (e *WrappedEvmEstimator) Ready() error { +func (e *evmFeeEstimator) Ready() error { var errEVM, errOracle error errEVM = e.EvmEstimator.Ready() - if e.l1Oracle != nil { - errOracle = e.l1Oracle.Ready() + l1Oracle := e.L1Oracle() + if l1Oracle != nil { + errOracle = l1Oracle.Ready() } if errEVM != nil { @@ -229,21 +242,23 @@ func (e *WrappedEvmEstimator) Ready() error { return errOracle } -func (e *WrappedEvmEstimator) HealthReport() map[string]error { +func (e *evmFeeEstimator) HealthReport() map[string]error { report := map[string]error{e.Name(): e.Healthy()} services.CopyHealth(report, e.EvmEstimator.HealthReport()) - if e.l1Oracle != nil { - services.CopyHealth(report, e.l1Oracle.HealthReport()) + + l1Oracle := e.L1Oracle() + if l1Oracle != nil { + services.CopyHealth(report, l1Oracle.HealthReport()) } return report } -func (e *WrappedEvmEstimator) L1Oracle() rollups.L1Oracle { - return e.l1Oracle +func (e *evmFeeEstimator) L1Oracle() rollups.L1Oracle { + return e.EvmEstimator.L1Oracle() } -func (e *WrappedEvmEstimator) GetFee(ctx context.Context, calldata []byte, feeLimit uint64, maxFeePrice *assets.Wei, opts ...feetypes.Opt) (fee EvmFee, chainSpecificFeeLimit uint64, err error) { +func (e *evmFeeEstimator) GetFee(ctx context.Context, calldata []byte, feeLimit uint64, maxFeePrice *assets.Wei, opts ...feetypes.Opt) (fee EvmFee, chainSpecificFeeLimit uint64, err error) { // get dynamic fee if e.EIP1559Enabled { var dynamicFee DynamicFee @@ -267,7 +282,7 @@ func (e *WrappedEvmEstimator) GetFee(ctx context.Context, calldata []byte, feeLi return } -func (e *WrappedEvmEstimator) GetMaxCost(ctx context.Context, amount assets.Eth, calldata []byte, feeLimit uint64, maxFeePrice *assets.Wei, opts ...feetypes.Opt) (*big.Int, error) { +func (e *evmFeeEstimator) GetMaxCost(ctx context.Context, amount assets.Eth, calldata []byte, feeLimit uint64, maxFeePrice *assets.Wei, opts ...feetypes.Opt) (*big.Int, error) { fees, gasLimit, err := e.GetFee(ctx, calldata, feeLimit, maxFeePrice, opts...) if err != nil { return nil, err @@ -285,7 +300,7 @@ func (e *WrappedEvmEstimator) GetMaxCost(ctx context.Context, amount assets.Eth, return amountWithFees, nil } -func (e *WrappedEvmEstimator) BumpFee(ctx context.Context, originalFee EvmFee, feeLimit uint64, maxFeePrice *assets.Wei, attempts []EvmPriorAttempt) (bumpedFee EvmFee, chainSpecificFeeLimit uint64, err error) { +func (e *evmFeeEstimator) BumpFee(ctx context.Context, originalFee EvmFee, feeLimit uint64, maxFeePrice *assets.Wei, attempts []EvmPriorAttempt) (bumpedFee EvmFee, chainSpecificFeeLimit uint64, err error) { // validate only 1 fee type is present if (!originalFee.ValidDynamic() && originalFee.Legacy == nil) || (originalFee.ValidDynamic() && originalFee.Legacy != nil) { err = pkgerrors.New("only one dynamic or legacy fee can be defined") diff --git a/core/chains/evm/gas/models_test.go b/core/chains/evm/gas/models_test.go index ec9542b4040..722beb8021a 100644 --- a/core/chains/evm/gas/models_test.go +++ b/core/chains/evm/gas/models_test.go @@ -10,9 +10,11 @@ import ( "github.com/stretchr/testify/require" "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink/v2/common/config" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/assets" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/gas" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/gas/mocks" + "github.com/smartcontractkit/chainlink/v2/core/chains/evm/gas/rollups" rollupMocks "github.com/smartcontractkit/chainlink/v2/core/chains/evm/gas/rollups/mocks" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" ) @@ -49,14 +51,24 @@ func TestWrappedEvmEstimator(t *testing.T) { // L1Oracle returns the correct L1Oracle interface t.Run("L1Oracle", func(t *testing.T) { lggr := logger.Test(t) + + evmEstimator := mocks.NewEvmEstimator(t) + evmEstimator.On("L1Oracle").Return(nil).Once() + + getEst := func(logger.Logger) gas.EvmEstimator { return evmEstimator } + // expect nil - estimator := gas.NewWrappedEvmEstimator(lggr, getRootEst, false, nil, nil) + estimator := gas.NewEvmFeeEstimator(lggr, getEst, false, nil) l1Oracle := estimator.L1Oracle() + assert.Nil(t, l1Oracle) // expect l1Oracle - oracle := rollupMocks.NewL1Oracle(t) - estimator = gas.NewWrappedEvmEstimator(lggr, getRootEst, false, oracle, geCfg) + oracle := rollups.NewL1GasOracle(lggr, nil, config.ChainOptimismBedrock) + // cast oracle to L1Oracle interface + estimator = gas.NewEvmFeeEstimator(lggr, getEst, false, geCfg) + + evmEstimator.On("L1Oracle").Return(oracle).Once() l1Oracle = estimator.L1Oracle() assert.Equal(t, oracle, l1Oracle) }) @@ -66,7 +78,7 @@ func TestWrappedEvmEstimator(t *testing.T) { lggr := logger.Test(t) // expect legacy fee data dynamicFees := false - estimator := gas.NewWrappedEvmEstimator(lggr, getRootEst, dynamicFees, nil, geCfg) + estimator := gas.NewEvmFeeEstimator(lggr, getRootEst, dynamicFees, geCfg) fee, max, err := estimator.GetFee(ctx, nil, 0, nil) require.NoError(t, err) assert.Equal(t, uint64(float32(gasLimit)*limitMultiplier), max) @@ -76,7 +88,7 @@ func TestWrappedEvmEstimator(t *testing.T) { // expect dynamic fee data dynamicFees = true - estimator = gas.NewWrappedEvmEstimator(lggr, getRootEst, dynamicFees, nil, geCfg) + estimator = gas.NewEvmFeeEstimator(lggr, getRootEst, dynamicFees, geCfg) fee, max, err = estimator.GetFee(ctx, nil, gasLimit, nil) require.NoError(t, err) assert.Equal(t, uint64(float32(gasLimit)*limitMultiplier), max) @@ -89,7 +101,7 @@ func TestWrappedEvmEstimator(t *testing.T) { t.Run("BumpFee", func(t *testing.T) { lggr := logger.Test(t) dynamicFees := false - estimator := gas.NewWrappedEvmEstimator(lggr, getRootEst, dynamicFees, nil, geCfg) + estimator := gas.NewEvmFeeEstimator(lggr, getRootEst, dynamicFees, geCfg) // expect legacy fee data fee, max, err := estimator.BumpFee(ctx, gas.EvmFee{Legacy: assets.NewWeiI(0)}, 0, nil, nil) @@ -127,7 +139,7 @@ func TestWrappedEvmEstimator(t *testing.T) { // expect legacy fee data dynamicFees := false - estimator := gas.NewWrappedEvmEstimator(lggr, getRootEst, dynamicFees, nil, geCfg) + estimator := gas.NewEvmFeeEstimator(lggr, getRootEst, dynamicFees, geCfg) total, err := estimator.GetMaxCost(ctx, val, nil, gasLimit, nil) require.NoError(t, err) fee := new(big.Int).Mul(legacyFee.ToInt(), big.NewInt(int64(gasLimit))) @@ -136,7 +148,7 @@ func TestWrappedEvmEstimator(t *testing.T) { // expect dynamic fee data dynamicFees = true - estimator = gas.NewWrappedEvmEstimator(lggr, getRootEst, dynamicFees, nil, geCfg) + estimator = gas.NewEvmFeeEstimator(lggr, getRootEst, dynamicFees, geCfg) total, err = estimator.GetMaxCost(ctx, val, nil, gasLimit, nil) require.NoError(t, err) fee = new(big.Int).Mul(dynamicFee.FeeCap.ToInt(), big.NewInt(int64(gasLimit))) @@ -147,13 +159,12 @@ func TestWrappedEvmEstimator(t *testing.T) { t.Run("Name", func(t *testing.T) { lggr := logger.Test(t) - oracle := rollupMocks.NewL1Oracle(t) evmEstimator := mocks.NewEvmEstimator(t) evmEstimator.On("Name").Return(mockEvmEstimatorName, nil).Once() - estimator := gas.NewWrappedEvmEstimator(lggr, func(logger.Logger) gas.EvmEstimator { + estimator := gas.NewEvmFeeEstimator(lggr, func(logger.Logger) gas.EvmEstimator { return evmEstimator - }, false, oracle, geCfg) + }, false, geCfg) require.Equal(t, mockEstimatorName, estimator.Name()) require.Equal(t, mockEvmEstimatorName, evmEstimator.Name()) @@ -170,13 +181,17 @@ func TestWrappedEvmEstimator(t *testing.T) { oracle.On("Close").Return(nil).Once() getEst := func(logger.Logger) gas.EvmEstimator { return evmEstimator } - estimator := gas.NewWrappedEvmEstimator(lggr, getEst, false, nil, geCfg) + evmEstimator.On("L1Oracle", mock.Anything).Return(nil).Twice() + + estimator := gas.NewEvmFeeEstimator(lggr, getEst, false, geCfg) err := estimator.Start(ctx) require.NoError(t, err) err = estimator.Close() require.NoError(t, err) - estimator = gas.NewWrappedEvmEstimator(lggr, getEst, false, oracle, geCfg) + evmEstimator.On("L1Oracle", mock.Anything).Return(oracle).Twice() + + estimator = gas.NewEvmFeeEstimator(lggr, getEst, false, geCfg) err = estimator.Start(ctx) require.NoError(t, err) err = estimator.Close() @@ -188,15 +203,16 @@ func TestWrappedEvmEstimator(t *testing.T) { evmEstimator := mocks.NewEvmEstimator(t) oracle := rollupMocks.NewL1Oracle(t) + evmEstimator.On("L1Oracle").Return(oracle).Twice() evmEstimator.On("Ready").Return(nil).Twice() - oracle.On("Ready").Return(nil).Once() + oracle.On("Ready").Return(nil).Twice() getEst := func(logger.Logger) gas.EvmEstimator { return evmEstimator } - estimator := gas.NewWrappedEvmEstimator(lggr, getEst, false, nil, geCfg) + estimator := gas.NewEvmFeeEstimator(lggr, getEst, false, geCfg) err := estimator.Ready() require.NoError(t, err) - estimator = gas.NewWrappedEvmEstimator(lggr, getEst, false, oracle, geCfg) + estimator = gas.NewEvmFeeEstimator(lggr, getEst, false, geCfg) err = estimator.Ready() require.NoError(t, err) }) @@ -211,17 +227,21 @@ func TestWrappedEvmEstimator(t *testing.T) { oracleKey := "oracle" oracleError := pkgerrors.New("oracle error") + evmEstimator.On("L1Oracle").Return(nil).Once() evmEstimator.On("HealthReport").Return(map[string]error{evmEstimatorKey: evmEstimatorError}).Twice() + oracle.On("HealthReport").Return(map[string]error{oracleKey: oracleError}).Once() getEst := func(logger.Logger) gas.EvmEstimator { return evmEstimator } - estimator := gas.NewWrappedEvmEstimator(lggr, getEst, false, nil, geCfg) + estimator := gas.NewEvmFeeEstimator(lggr, getEst, false, geCfg) report := estimator.HealthReport() require.True(t, pkgerrors.Is(report[evmEstimatorKey], evmEstimatorError)) require.Nil(t, report[oracleKey]) require.NotNil(t, report[mockEstimatorName]) - estimator = gas.NewWrappedEvmEstimator(lggr, getEst, false, oracle, geCfg) + evmEstimator.On("L1Oracle").Return(oracle).Once() + + estimator = gas.NewEvmFeeEstimator(lggr, getEst, false, geCfg) report = estimator.HealthReport() require.True(t, pkgerrors.Is(report[evmEstimatorKey], evmEstimatorError)) require.True(t, pkgerrors.Is(report[oracleKey], oracleError)) diff --git a/core/chains/evm/gas/rollups/arbitrum_l1_oracle.go b/core/chains/evm/gas/rollups/arbitrum_l1_oracle.go new file mode 100644 index 00000000000..d0b4c5808ad --- /dev/null +++ b/core/chains/evm/gas/rollups/arbitrum_l1_oracle.go @@ -0,0 +1,300 @@ +package rollups + +import ( + "context" + "fmt" + "math" + "math/big" + "strings" + "sync" + "time" + + "github.com/ethereum/go-ethereum" + "github.com/ethereum/go-ethereum/accounts/abi" + "github.com/ethereum/go-ethereum/common" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/utils" + + gethtypes "github.com/ethereum/go-ethereum/core/types" + + "github.com/smartcontractkit/chainlink/v2/common/client" + "github.com/smartcontractkit/chainlink/v2/common/config" + "github.com/smartcontractkit/chainlink/v2/core/chains/evm/assets" + evmclient "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" +) + +type ArbL1GasOracle interface { + L1Oracle + GetPricesInArbGas() (perL2Tx uint32, perL1CalldataUnit uint32, err error) +} + +// Reads L2-specific precompiles and caches the l1GasPrice set by the L2. +type arbitrumL1Oracle struct { + services.StateMachine + client l1OracleClient + pollPeriod time.Duration + logger logger.SugaredLogger + chainType config.ChainType + + l1GasPriceAddress string + gasPriceMethod string + l1GasPriceMethodAbi abi.ABI + l1GasPriceMu sync.RWMutex + l1GasPrice priceEntry + + l1GasCostAddress string + gasCostMethod string + l1GasCostMethodAbi abi.ABI + + chInitialised chan struct{} + chStop services.StopChan + chDone chan struct{} +} + +const ( + // ArbGasInfoAddress is the address of the "Precompiled contract that exists in every Arbitrum chain." + // https://github.com/OffchainLabs/nitro/blob/f7645453cfc77bf3e3644ea1ac031eff629df325/contracts/src/precompiles/ArbGasInfo.sol + ArbGasInfoAddress = "0x000000000000000000000000000000000000006C" + // ArbGasInfo_getL1BaseFeeEstimate is the a hex encoded call to: + // `function getL1BaseFeeEstimate() external view returns (uint256);` + ArbGasInfo_getL1BaseFeeEstimate = "getL1BaseFeeEstimate" + // NodeInterfaceAddress is the address of the precompiled contract that is only available through RPC + // https://github.com/OffchainLabs/nitro/blob/e815395d2e91fb17f4634cad72198f6de79c6e61/nodeInterface/NodeInterface.go#L37 + ArbNodeInterfaceAddress = "0x00000000000000000000000000000000000000C8" + // ArbGasInfo_getPricesInArbGas is the a hex encoded call to: + // `function gasEstimateL1Component(address to, bool contractCreation, bytes calldata data) external payable returns (uint64 gasEstimateForL1, uint256 baseFee, uint256 l1BaseFeeEstimate);` + ArbNodeInterface_gasEstimateL1Component = "gasEstimateL1Component" + // ArbGasInfo_getPricesInArbGas is the a hex encoded call to: + // `function getPricesInArbGas() external view returns (uint256, uint256, uint256);` + ArbGasInfo_getPricesInArbGas = "02199f34" +) + +func NewArbitrumL1GasOracle(lggr logger.Logger, ethClient l1OracleClient) *arbitrumL1Oracle { + var l1GasPriceAddress, gasPriceMethod, l1GasCostAddress, gasCostMethod string + var l1GasPriceMethodAbi, l1GasCostMethodAbi abi.ABI + var gasPriceErr, gasCostErr error + + l1GasPriceAddress = ArbGasInfoAddress + gasPriceMethod = ArbGasInfo_getL1BaseFeeEstimate + l1GasPriceMethodAbi, gasPriceErr = abi.JSON(strings.NewReader(GetL1BaseFeeEstimateAbiString)) + l1GasCostAddress = ArbNodeInterfaceAddress + gasCostMethod = ArbNodeInterface_gasEstimateL1Component + l1GasCostMethodAbi, gasCostErr = abi.JSON(strings.NewReader(GasEstimateL1ComponentAbiString)) + + if gasPriceErr != nil { + panic("Failed to parse L1 gas price method ABI for chain: arbitrum") + } + if gasCostErr != nil { + panic("Failed to parse L1 gas cost method ABI for chain: arbitrum") + } + + return &arbitrumL1Oracle{ + client: ethClient, + pollPeriod: PollPeriod, + logger: logger.Sugared(logger.Named(lggr, "L1GasOracle(arbitrum)")), + chainType: config.ChainArbitrum, + + l1GasPriceAddress: l1GasPriceAddress, + gasPriceMethod: gasPriceMethod, + l1GasPriceMethodAbi: l1GasPriceMethodAbi, + l1GasCostAddress: l1GasCostAddress, + gasCostMethod: gasCostMethod, + l1GasCostMethodAbi: l1GasCostMethodAbi, + + chInitialised: make(chan struct{}), + chStop: make(chan struct{}), + chDone: make(chan struct{}), + } +} + +func (o *arbitrumL1Oracle) Name() string { + return o.logger.Name() +} + +func (o *arbitrumL1Oracle) Start(ctx context.Context) error { + return o.StartOnce(o.Name(), func() error { + go o.run() + <-o.chInitialised + return nil + }) +} +func (o *arbitrumL1Oracle) Close() error { + return o.StopOnce(o.Name(), func() error { + close(o.chStop) + <-o.chDone + return nil + }) +} + +func (o *arbitrumL1Oracle) HealthReport() map[string]error { + return map[string]error{o.Name(): o.Healthy()} +} + +func (o *arbitrumL1Oracle) run() { + defer close(o.chDone) + + t := o.refresh() + close(o.chInitialised) + + for { + select { + case <-o.chStop: + return + case <-t.C: + t = o.refresh() + } + } +} +func (o *arbitrumL1Oracle) refresh() (t *time.Timer) { + t, err := o.refreshWithError() + if err != nil { + o.SvcErrBuffer.Append(err) + } + return +} + +func (o *arbitrumL1Oracle) refreshWithError() (t *time.Timer, err error) { + t = time.NewTimer(utils.WithJitter(o.pollPeriod)) + + ctx, cancel := o.chStop.CtxCancel(evmclient.ContextWithDefaultTimeout()) + defer cancel() + + price, err := o.fetchL1GasPrice(ctx) + if err != nil { + return t, err + } + + o.l1GasPriceMu.Lock() + defer o.l1GasPriceMu.Unlock() + o.l1GasPrice = priceEntry{price: assets.NewWei(price), timestamp: time.Now()} + return +} + +func (o *arbitrumL1Oracle) fetchL1GasPrice(ctx context.Context) (price *big.Int, err error) { + var callData, b []byte + precompile := common.HexToAddress(o.l1GasPriceAddress) + callData, err = o.l1GasPriceMethodAbi.Pack(o.gasPriceMethod) + if err != nil { + errMsg := "failed to pack calldata for arbitrum L1 gas price method" + o.logger.Errorf(errMsg) + return nil, fmt.Errorf("%s: %w", errMsg, err) + } + b, err = o.client.CallContract(ctx, ethereum.CallMsg{ + To: &precompile, + Data: callData, + }, nil) + if err != nil { + errMsg := "gas oracle contract call failed" + o.logger.Errorf(errMsg) + return nil, fmt.Errorf("%s: %w", errMsg, err) + } + + if len(b) != 32 { // returns uint256; + errMsg := fmt.Sprintf("return data length (%d) different than expected (%d)", len(b), 32) + o.logger.Criticalf(errMsg) + return nil, fmt.Errorf(errMsg) + } + price = new(big.Int).SetBytes(b) + return price, nil +} + +func (o *arbitrumL1Oracle) GasPrice(_ context.Context) (l1GasPrice *assets.Wei, err error) { + var timestamp time.Time + ok := o.IfStarted(func() { + o.l1GasPriceMu.RLock() + l1GasPrice = o.l1GasPrice.price + timestamp = o.l1GasPrice.timestamp + o.l1GasPriceMu.RUnlock() + }) + if !ok { + return l1GasPrice, fmt.Errorf("L1GasOracle is not started; cannot estimate gas") + } + if l1GasPrice == nil { + return l1GasPrice, fmt.Errorf("failed to get l1 gas price; gas price not set") + } + // Validate the price has been updated within the pollPeriod * 2 + // Allowing double the poll period before declaring the price stale to give ample time for the refresh to process + if time.Since(timestamp) > o.pollPeriod*2 { + return l1GasPrice, fmt.Errorf("gas price is stale") + } + return +} + +// Gets the L1 gas cost for the provided transaction at the specified block num +// If block num is not provided, the value on the latest block num is used +func (o *arbitrumL1Oracle) GetGasCost(ctx context.Context, tx *gethtypes.Transaction, blockNum *big.Int) (*assets.Wei, error) { + ctx, cancel := context.WithTimeout(ctx, client.QueryTimeout) + defer cancel() + var callData, b []byte + var err error + + if callData, err = o.l1GasCostMethodAbi.Pack(o.gasCostMethod, tx.To(), false, tx.Data()); err != nil { + return nil, fmt.Errorf("failed to pack calldata for %s L1 gas cost estimation method: %w", o.chainType, err) + } + + precompile := common.HexToAddress(o.l1GasCostAddress) + b, err = o.client.CallContract(ctx, ethereum.CallMsg{ + To: &precompile, + Data: callData, + }, blockNum) + if err != nil { + errorMsg := fmt.Sprintf("gas oracle contract call failed: %v", err) + o.logger.Errorf(errorMsg) + return nil, fmt.Errorf(errorMsg) + } + + var l1GasCost *big.Int + + if len(b) != 8+2*32 { // returns (uint64 gasEstimateForL1, uint256 baseFee, uint256 l1BaseFeeEstimate); + errorMsg := fmt.Sprintf("return data length (%d) different than expected (%d)", len(b), 8+2*32) + o.logger.Critical(errorMsg) + return nil, fmt.Errorf(errorMsg) + } + l1GasCost = new(big.Int).SetBytes(b[:8]) + + return assets.NewWei(l1GasCost), nil +} + +// callGetPricesInArbGas calls ArbGasInfo.getPricesInArbGas() on the precompile contract ArbGasInfoAddress. +// +// @return (per L2 tx, per L1 calldata unit, per storage allocation) +// function getPricesInArbGas() external view returns (uint256, uint256, uint256); +// +// https://github.com/OffchainLabs/nitro/blob/f7645453cfc77bf3e3644ea1ac031eff629df325/contracts/src/precompiles/ArbGasInfo.sol#L69 + +func (o *arbitrumL1Oracle) GetPricesInArbGas() (perL2Tx uint32, perL1CalldataUnit uint32, err error) { + ctx, cancel := o.chStop.CtxCancel(evmclient.ContextWithDefaultTimeout()) + defer cancel() + precompile := common.HexToAddress(ArbGasInfoAddress) + b, err := o.client.CallContract(ctx, ethereum.CallMsg{ + To: &precompile, + Data: common.Hex2Bytes(ArbGasInfo_getPricesInArbGas), + }, big.NewInt(-1)) + if err != nil { + return 0, 0, err + } + + if len(b) != 3*32 { // returns (uint256, uint256, uint256); + err = fmt.Errorf("return data length (%d) different than expected (%d)", len(b), 3*32) + return + } + bPerL2Tx := new(big.Int).SetBytes(b[:32]) + bPerL1CalldataUnit := new(big.Int).SetBytes(b[32:64]) + // ignore perStorageAllocation + if !bPerL2Tx.IsUint64() || !bPerL1CalldataUnit.IsUint64() { + err = fmt.Errorf("returned integers are not uint64 (%s, %s)", bPerL2Tx.String(), bPerL1CalldataUnit.String()) + return + } + + perL2TxU64 := bPerL2Tx.Uint64() + perL1CalldataUnitU64 := bPerL1CalldataUnit.Uint64() + if perL2TxU64 > math.MaxUint32 || perL1CalldataUnitU64 > math.MaxUint32 { + err = fmt.Errorf("returned integers are not uint32 (%d, %d)", perL2TxU64, perL1CalldataUnitU64) + return + } + perL2Tx = uint32(perL2TxU64) + perL1CalldataUnit = uint32(perL1CalldataUnitU64) + return +} diff --git a/core/chains/evm/gas/rollups/l1_oracle.go b/core/chains/evm/gas/rollups/l1_oracle.go index ae46071cf0d..05ceb720ab2 100644 --- a/core/chains/evm/gas/rollups/l1_oracle.go +++ b/core/chains/evm/gas/rollups/l1_oracle.go @@ -5,36 +5,35 @@ import ( "fmt" "math/big" "slices" - "strings" - "sync" "time" "github.com/ethereum/go-ethereum" - "github.com/ethereum/go-ethereum/accounts/abi" - "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/rpc" - "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/services" - "github.com/smartcontractkit/chainlink-common/pkg/utils" - gethtypes "github.com/ethereum/go-ethereum/core/types" + "github.com/smartcontractkit/chainlink-common/pkg/logger" - "github.com/smartcontractkit/chainlink/v2/common/client" "github.com/smartcontractkit/chainlink/v2/common/config" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/assets" - evmclient "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" ) -//go:generate mockery --quiet --name ethClient --output ./mocks/ --case=underscore --structname ETHClient -type ethClient interface { - CallContract(ctx context.Context, msg ethereum.CallMsg, blockNumber *big.Int) ([]byte, error) - BatchCallContext(ctx context.Context, b []rpc.BatchElem) error +// L1Oracle provides interface for fetching L1-specific fee components if the chain is an L2. +// For example, on Optimistic Rollups, this oracle can return rollup-specific l1BaseFee +// +//go:generate mockery --quiet --name L1Oracle --output ./mocks/ --case=underscore +type L1Oracle interface { + services.Service + + GasPrice(ctx context.Context) (*assets.Wei, error) + GetGasCost(ctx context.Context, tx *types.Transaction, blockNum *big.Int) (*assets.Wei, error) } -//go:generate mockery --quiet --name daPriceReader --output ./mocks/ --case=underscore --structname DAPriceReader -type daPriceReader interface { - GetDAGasPrice(ctx context.Context) (*big.Int, error) +//go:generate mockery --quiet --name l1OracleClient --output ./mocks/ --case=underscore --structname L1OracleClient +type l1OracleClient interface { + CallContract(ctx context.Context, msg ethereum.CallMsg, blockNumber *big.Int) ([]byte, error) + BatchCallContext(ctx context.Context, b []rpc.BatchElem) error } type priceEntry struct { @@ -42,71 +41,7 @@ type priceEntry struct { timestamp time.Time } -// Reads L2-specific precompiles and caches the l1GasPrice set by the L2. -type l1Oracle struct { - services.StateMachine - client ethClient - pollPeriod time.Duration - logger logger.SugaredLogger - chainType config.ChainType - - l1GasPriceAddress string - gasPriceMethod string - l1GasPriceMethodAbi abi.ABI - l1GasPriceMu sync.RWMutex - l1GasPrice priceEntry - - l1GasCostAddress string - gasCostMethod string - l1GasCostMethodAbi abi.ABI - - priceReader daPriceReader - - chInitialised chan struct{} - chStop services.StopChan - chDone chan struct{} -} - const ( - // ArbGasInfoAddress is the address of the "Precompiled contract that exists in every Arbitrum chain." - // https://github.com/OffchainLabs/nitro/blob/f7645453cfc77bf3e3644ea1ac031eff629df325/contracts/src/precompiles/ArbGasInfo.sol - ArbGasInfoAddress = "0x000000000000000000000000000000000000006C" - // ArbGasInfo_getL1BaseFeeEstimate is the a hex encoded call to: - // `function getL1BaseFeeEstimate() external view returns (uint256);` - ArbGasInfo_getL1BaseFeeEstimate = "getL1BaseFeeEstimate" - // NodeInterfaceAddress is the address of the precompiled contract that is only available through RPC - // https://github.com/OffchainLabs/nitro/blob/e815395d2e91fb17f4634cad72198f6de79c6e61/nodeInterface/NodeInterface.go#L37 - ArbNodeInterfaceAddress = "0x00000000000000000000000000000000000000C8" - // ArbGasInfo_getPricesInArbGas is the a hex encoded call to: - // `function gasEstimateL1Component(address to, bool contractCreation, bytes calldata data) external payable returns (uint64 gasEstimateForL1, uint256 baseFee, uint256 l1BaseFeeEstimate);` - ArbNodeInterface_gasEstimateL1Component = "gasEstimateL1Component" - - // OPGasOracleAddress is the address of the precompiled contract that exists on OP stack chain. - // This is the case for Optimism and Base. - OPGasOracleAddress = "0x420000000000000000000000000000000000000F" - // OPGasOracle_l1BaseFee is a hex encoded call to: - // `function l1BaseFee() external view returns (uint256);` - OPGasOracle_l1BaseFee = "l1BaseFee" - // OPGasOracle_getL1Fee is a hex encoded call to: - // `function getL1Fee(bytes) external view returns (uint256);` - OPGasOracle_getL1Fee = "getL1Fee" - - // ScrollGasOracleAddress is the address of the precompiled contract that exists on Scroll chain. - ScrollGasOracleAddress = "0x5300000000000000000000000000000000000002" - // ScrollGasOracle_l1BaseFee is a hex encoded call to: - // `function l1BaseFee() external view returns (uint256);` - ScrollGasOracle_l1BaseFee = "l1BaseFee" - // ScrollGasOracle_getL1Fee is a hex encoded call to: - // `function getL1Fee(bytes) external view returns (uint256);` - ScrollGasOracle_getL1Fee = "getL1Fee" - - // GasOracleAddress is the address of the precompiled contract that exists on Kroma chain. - // This is the case for Kroma. - KromaGasOracleAddress = "0x4200000000000000000000000000000000000005" - // GasOracle_l1BaseFee is the a hex encoded call to: - // `function l1BaseFee() external view returns (uint256);` - KromaGasOracle_l1BaseFee = "l1BaseFee" - // Interval at which to poll for L1BaseFee. A good starting point is the L1 block time. PollPeriod = 6 * time.Second ) @@ -117,253 +52,18 @@ func IsRollupWithL1Support(chainType config.ChainType) bool { return slices.Contains(supportedChainTypes, chainType) } -func NewL1GasOracle(lggr logger.Logger, ethClient ethClient, chainType config.ChainType) L1Oracle { - var priceReader daPriceReader - switch chainType { - case config.ChainOptimismBedrock: - priceReader = newOPPriceReader(lggr, ethClient, chainType, OPGasOracleAddress) - case config.ChainKroma: - priceReader = newOPPriceReader(lggr, ethClient, chainType, KromaGasOracleAddress) - default: - priceReader = nil +func NewL1GasOracle(lggr logger.Logger, ethClient l1OracleClient, chainType config.ChainType) L1Oracle { + if !IsRollupWithL1Support(chainType) { + return nil } - return newL1GasOracle(lggr, ethClient, chainType, priceReader) -} - -func newL1GasOracle(lggr logger.Logger, ethClient ethClient, chainType config.ChainType, priceReader daPriceReader) L1Oracle { - var l1GasPriceAddress, gasPriceMethod, l1GasCostAddress, gasCostMethod string - var l1GasPriceMethodAbi, l1GasCostMethodAbi abi.ABI - var gasPriceErr, gasCostErr error - + var l1Oracle L1Oracle switch chainType { + case config.ChainOptimismBedrock, config.ChainKroma, config.ChainScroll: + l1Oracle = NewOpStackL1GasOracle(lggr, ethClient, chainType) case config.ChainArbitrum: - l1GasPriceAddress = ArbGasInfoAddress - gasPriceMethod = ArbGasInfo_getL1BaseFeeEstimate - l1GasPriceMethodAbi, gasPriceErr = abi.JSON(strings.NewReader(GetL1BaseFeeEstimateAbiString)) - l1GasCostAddress = ArbNodeInterfaceAddress - gasCostMethod = ArbNodeInterface_gasEstimateL1Component - l1GasCostMethodAbi, gasCostErr = abi.JSON(strings.NewReader(GasEstimateL1ComponentAbiString)) - case config.ChainOptimismBedrock: - l1GasPriceAddress = OPGasOracleAddress - gasPriceMethod = OPGasOracle_l1BaseFee - l1GasPriceMethodAbi, gasPriceErr = abi.JSON(strings.NewReader(L1BaseFeeAbiString)) - l1GasCostAddress = OPGasOracleAddress - gasCostMethod = OPGasOracle_getL1Fee - l1GasCostMethodAbi, gasCostErr = abi.JSON(strings.NewReader(GetL1FeeAbiString)) - case config.ChainKroma: - l1GasPriceAddress = KromaGasOracleAddress - gasPriceMethod = KromaGasOracle_l1BaseFee - l1GasPriceMethodAbi, gasPriceErr = abi.JSON(strings.NewReader(L1BaseFeeAbiString)) - l1GasCostAddress = "" - gasCostMethod = "" - case config.ChainScroll: - l1GasPriceAddress = ScrollGasOracleAddress - gasPriceMethod = ScrollGasOracle_l1BaseFee - l1GasPriceMethodAbi, gasPriceErr = abi.JSON(strings.NewReader(L1BaseFeeAbiString)) - l1GasCostAddress = ScrollGasOracleAddress - gasCostMethod = ScrollGasOracle_getL1Fee - l1GasCostMethodAbi, gasCostErr = abi.JSON(strings.NewReader(GetL1FeeAbiString)) + l1Oracle = NewArbitrumL1GasOracle(lggr, ethClient) default: panic(fmt.Sprintf("Received unspported chaintype %s", chainType)) } - - if gasPriceErr != nil { - panic(fmt.Sprintf("Failed to parse L1 gas price method ABI for chain: %s", chainType)) - } - if gasCostErr != nil { - panic(fmt.Sprintf("Failed to parse L1 gas cost method ABI for chain: %s", chainType)) - } - - return &l1Oracle{ - client: ethClient, - pollPeriod: PollPeriod, - logger: logger.Sugared(logger.Named(lggr, fmt.Sprintf("L1GasOracle(%s)", chainType))), - chainType: chainType, - - l1GasPriceAddress: l1GasPriceAddress, - gasPriceMethod: gasPriceMethod, - l1GasPriceMethodAbi: l1GasPriceMethodAbi, - l1GasCostAddress: l1GasCostAddress, - gasCostMethod: gasCostMethod, - l1GasCostMethodAbi: l1GasCostMethodAbi, - - priceReader: priceReader, - - chInitialised: make(chan struct{}), - chStop: make(chan struct{}), - chDone: make(chan struct{}), - } -} - -func (o *l1Oracle) Name() string { - return o.logger.Name() -} - -func (o *l1Oracle) Start(ctx context.Context) error { - return o.StartOnce(o.Name(), func() error { - go o.run() - <-o.chInitialised - return nil - }) -} -func (o *l1Oracle) Close() error { - return o.StopOnce(o.Name(), func() error { - close(o.chStop) - <-o.chDone - return nil - }) -} - -func (o *l1Oracle) HealthReport() map[string]error { - return map[string]error{o.Name(): o.Healthy()} -} - -func (o *l1Oracle) run() { - defer close(o.chDone) - - t := o.refresh() - close(o.chInitialised) - - for { - select { - case <-o.chStop: - return - case <-t.C: - t = o.refresh() - } - } -} -func (o *l1Oracle) refresh() (t *time.Timer) { - t, err := o.refreshWithError() - if err != nil { - o.SvcErrBuffer.Append(err) - } - return -} - -func (o *l1Oracle) refreshWithError() (t *time.Timer, err error) { - t = time.NewTimer(utils.WithJitter(o.pollPeriod)) - - ctx, cancel := o.chStop.CtxCancel(evmclient.ContextWithDefaultTimeout()) - defer cancel() - - price, err := o.fetchL1GasPrice(ctx) - if err != nil { - return t, err - } - - o.l1GasPriceMu.Lock() - defer o.l1GasPriceMu.Unlock() - o.l1GasPrice = priceEntry{price: assets.NewWei(price), timestamp: time.Now()} - return -} - -func (o *l1Oracle) fetchL1GasPrice(ctx context.Context) (price *big.Int, err error) { - // if dedicated priceReader exists, use the reader - if o.priceReader != nil { - return o.priceReader.GetDAGasPrice(ctx) - } - - var callData, b []byte - precompile := common.HexToAddress(o.l1GasPriceAddress) - callData, err = o.l1GasPriceMethodAbi.Pack(o.gasPriceMethod) - if err != nil { - errMsg := fmt.Sprintf("failed to pack calldata for %s L1 gas price method", o.chainType) - o.logger.Errorf(errMsg) - return nil, fmt.Errorf("%s: %w", errMsg, err) - } - b, err = o.client.CallContract(ctx, ethereum.CallMsg{ - To: &precompile, - Data: callData, - }, nil) - if err != nil { - errMsg := "gas oracle contract call failed" - o.logger.Errorf(errMsg) - return nil, fmt.Errorf("%s: %w", errMsg, err) - } - - if len(b) != 32 { // returns uint256; - errMsg := fmt.Sprintf("return data length (%d) different than expected (%d)", len(b), 32) - o.logger.Criticalf(errMsg) - return nil, fmt.Errorf(errMsg) - } - price = new(big.Int).SetBytes(b) - return price, nil -} - -func (o *l1Oracle) GasPrice(_ context.Context) (l1GasPrice *assets.Wei, err error) { - var timestamp time.Time - ok := o.IfStarted(func() { - o.l1GasPriceMu.RLock() - l1GasPrice = o.l1GasPrice.price - timestamp = o.l1GasPrice.timestamp - o.l1GasPriceMu.RUnlock() - }) - if !ok { - return l1GasPrice, fmt.Errorf("L1GasOracle is not started; cannot estimate gas") - } - if l1GasPrice == nil { - return l1GasPrice, fmt.Errorf("failed to get l1 gas price; gas price not set") - } - // Validate the price has been updated within the pollPeriod * 2 - // Allowing double the poll period before declaring the price stale to give ample time for the refresh to process - if time.Since(timestamp) > o.pollPeriod*2 { - return l1GasPrice, fmt.Errorf("gas price is stale") - } - return -} - -// Gets the L1 gas cost for the provided transaction at the specified block num -// If block num is not provided, the value on the latest block num is used -func (o *l1Oracle) GetGasCost(ctx context.Context, tx *gethtypes.Transaction, blockNum *big.Int) (*assets.Wei, error) { - ctx, cancel := context.WithTimeout(ctx, client.QueryTimeout) - defer cancel() - var callData, b []byte - var err error - if o.chainType == config.ChainOptimismBedrock || o.chainType == config.ChainScroll { - // Append rlp-encoded tx - var encodedtx []byte - if encodedtx, err = tx.MarshalBinary(); err != nil { - return nil, fmt.Errorf("failed to marshal tx for gas cost estimation: %w", err) - } - if callData, err = o.l1GasCostMethodAbi.Pack(o.gasCostMethod, encodedtx); err != nil { - return nil, fmt.Errorf("failed to pack calldata for %s L1 gas cost estimation method: %w", o.chainType, err) - } - } else if o.chainType == config.ChainArbitrum { - if callData, err = o.l1GasCostMethodAbi.Pack(o.gasCostMethod, tx.To(), false, tx.Data()); err != nil { - return nil, fmt.Errorf("failed to pack calldata for %s L1 gas cost estimation method: %w", o.chainType, err) - } - } else { - return nil, fmt.Errorf("L1 gas cost not supported for this chain: %s", o.chainType) - } - - precompile := common.HexToAddress(o.l1GasCostAddress) - b, err = o.client.CallContract(ctx, ethereum.CallMsg{ - To: &precompile, - Data: callData, - }, blockNum) - if err != nil { - errorMsg := fmt.Sprintf("gas oracle contract call failed: %v", err) - o.logger.Errorf(errorMsg) - return nil, fmt.Errorf(errorMsg) - } - - var l1GasCost *big.Int - if o.chainType == config.ChainOptimismBedrock || o.chainType == config.ChainScroll { - if len(b) != 32 { // returns uint256; - errorMsg := fmt.Sprintf("return data length (%d) different than expected (%d)", len(b), 32) - o.logger.Critical(errorMsg) - return nil, fmt.Errorf(errorMsg) - } - l1GasCost = new(big.Int).SetBytes(b) - } else if o.chainType == config.ChainArbitrum { - if len(b) != 8+2*32 { // returns (uint64 gasEstimateForL1, uint256 baseFee, uint256 l1BaseFeeEstimate); - errorMsg := fmt.Sprintf("return data length (%d) different than expected (%d)", len(b), 8+2*32) - o.logger.Critical(errorMsg) - return nil, fmt.Errorf(errorMsg) - } - l1GasCost = new(big.Int).SetBytes(b[:8]) - } - - return assets.NewWei(l1GasCost), nil + return l1Oracle } diff --git a/core/chains/evm/gas/rollups/l1_oracle_test.go b/core/chains/evm/gas/rollups/l1_oracle_test.go index 4f3b67e2ecf..6efdda6bcff 100644 --- a/core/chains/evm/gas/rollups/l1_oracle_test.go +++ b/core/chains/evm/gas/rollups/l1_oracle_test.go @@ -1,6 +1,7 @@ package rollups import ( + "errors" "math/big" "strings" "testing" @@ -27,9 +28,9 @@ func TestL1Oracle(t *testing.T) { t.Parallel() t.Run("Unsupported ChainType returns nil", func(t *testing.T) { - ethClient := mocks.NewETHClient(t) + ethClient := mocks.NewL1OracleClient(t) - assert.Panicsf(t, func() { NewL1GasOracle(logger.Test(t), ethClient, config.ChainCelo) }, "Received unspported chaintype %s", config.ChainCelo) + assert.Nil(t, NewL1GasOracle(logger.Test(t), ethClient, config.ChainCelo)) }) } @@ -37,7 +38,7 @@ func TestL1Oracle_GasPrice(t *testing.T) { t.Parallel() t.Run("Calling GasPrice on unstarted L1Oracle returns error", func(t *testing.T) { - ethClient := mocks.NewETHClient(t) + ethClient := mocks.NewL1OracleClient(t) oracle := NewL1GasOracle(logger.Test(t), ethClient, config.ChainOptimismBedrock) @@ -50,7 +51,7 @@ func TestL1Oracle_GasPrice(t *testing.T) { l1GasPriceMethodAbi, err := abi.JSON(strings.NewReader(GetL1BaseFeeEstimateAbiString)) require.NoError(t, err) - ethClient := mocks.NewETHClient(t) + ethClient := mocks.NewL1OracleClient(t) ethClient.On("CallContract", mock.Anything, mock.IsType(ethereum.CallMsg{}), mock.IsType(&big.Int{})).Run(func(args mock.Arguments) { callMsg := args.Get(1).(ethereum.CallMsg) blockNumber := args.Get(2).(*big.Int) @@ -73,10 +74,34 @@ func TestL1Oracle_GasPrice(t *testing.T) { t.Run("Calling GasPrice on started Kroma L1Oracle returns Kroma l1GasPrice", func(t *testing.T) { l1BaseFee := big.NewInt(100) - priceReader := mocks.NewDAPriceReader(t) - priceReader.On("GetDAGasPrice", mock.Anything).Return(l1BaseFee, nil) + l1GasPriceMethodAbi, err := abi.JSON(strings.NewReader(L1BaseFeeAbiString)) + require.NoError(t, err) + + isEcotoneAbiString, err := abi.JSON(strings.NewReader(OPIsEcotoneAbiString)) + require.NoError(t, err) + + ethClient := mocks.NewL1OracleClient(t) + ethClient.On("CallContract", mock.Anything, mock.IsType(ethereum.CallMsg{}), mock.IsType(&big.Int{})).Run(func(args mock.Arguments) { + callMsg := args.Get(1).(ethereum.CallMsg) + blockNumber := args.Get(2).(*big.Int) + var payload []byte + payload, err = isEcotoneAbiString.Pack("isEcotone") + require.NoError(t, err) + require.Equal(t, payload, callMsg.Data) + assert.Nil(t, blockNumber) + }).Return(nil, errors.New("not ecotone")).Once() - oracle := newL1GasOracle(logger.Test(t), nil, config.ChainKroma, priceReader) + ethClient.On("CallContract", mock.Anything, mock.IsType(ethereum.CallMsg{}), mock.IsType(&big.Int{})).Run(func(args mock.Arguments) { + callMsg := args.Get(1).(ethereum.CallMsg) + blockNumber := args.Get(2).(*big.Int) + var payload []byte + payload, err = l1GasPriceMethodAbi.Pack("l1BaseFee") + require.NoError(t, err) + require.Equal(t, payload, callMsg.Data) + assert.Nil(t, blockNumber) + }).Return(common.BigToHash(l1BaseFee).Bytes(), nil) + + oracle := newOpStackL1GasOracle(logger.Test(t), ethClient, config.ChainKroma, KromaGasOracleAddress) servicetest.RunHealthy(t, oracle) gasPrice, err := oracle.GasPrice(testutils.Context(t)) @@ -88,10 +113,34 @@ func TestL1Oracle_GasPrice(t *testing.T) { t.Run("Calling GasPrice on started OPStack L1Oracle returns OPStack l1GasPrice", func(t *testing.T) { l1BaseFee := big.NewInt(100) - priceReader := mocks.NewDAPriceReader(t) - priceReader.On("GetDAGasPrice", mock.Anything).Return(l1BaseFee, nil) + l1GasPriceMethodAbi, err := abi.JSON(strings.NewReader(L1BaseFeeAbiString)) + require.NoError(t, err) + + isEcotoneAbiString, err := abi.JSON(strings.NewReader(OPIsEcotoneAbiString)) + require.NoError(t, err) + + ethClient := mocks.NewL1OracleClient(t) + ethClient.On("CallContract", mock.Anything, mock.IsType(ethereum.CallMsg{}), mock.IsType(&big.Int{})).Run(func(args mock.Arguments) { + callMsg := args.Get(1).(ethereum.CallMsg) + blockNumber := args.Get(2).(*big.Int) + var payload []byte + payload, err = isEcotoneAbiString.Pack("isEcotone") + require.NoError(t, err) + require.Equal(t, payload, callMsg.Data) + assert.Nil(t, blockNumber) + }).Return(nil, errors.New("not ecotone")).Once() - oracle := newL1GasOracle(logger.Test(t), nil, config.ChainOptimismBedrock, priceReader) + ethClient.On("CallContract", mock.Anything, mock.IsType(ethereum.CallMsg{}), mock.IsType(&big.Int{})).Run(func(args mock.Arguments) { + callMsg := args.Get(1).(ethereum.CallMsg) + blockNumber := args.Get(2).(*big.Int) + var payload []byte + payload, err = l1GasPriceMethodAbi.Pack("l1BaseFee") + require.NoError(t, err) + require.Equal(t, payload, callMsg.Data) + assert.Nil(t, blockNumber) + }).Return(common.BigToHash(l1BaseFee).Bytes(), nil) + + oracle := newOpStackL1GasOracle(logger.Test(t), ethClient, config.ChainOptimismBedrock, OPGasOracleAddress) servicetest.RunHealthy(t, oracle) gasPrice, err := oracle.GasPrice(testutils.Context(t)) @@ -105,7 +154,20 @@ func TestL1Oracle_GasPrice(t *testing.T) { l1GasPriceMethodAbi, err := abi.JSON(strings.NewReader(L1BaseFeeAbiString)) require.NoError(t, err) - ethClient := mocks.NewETHClient(t) + isEcotoneAbiString, err := abi.JSON(strings.NewReader(OPIsEcotoneAbiString)) + require.NoError(t, err) + + ethClient := mocks.NewL1OracleClient(t) + ethClient.On("CallContract", mock.Anything, mock.IsType(ethereum.CallMsg{}), mock.IsType(&big.Int{})).Run(func(args mock.Arguments) { + callMsg := args.Get(1).(ethereum.CallMsg) + blockNumber := args.Get(2).(*big.Int) + var payload []byte + payload, err = isEcotoneAbiString.Pack("isEcotone") + require.NoError(t, err) + require.Equal(t, payload, callMsg.Data) + assert.Nil(t, blockNumber) + }).Return(nil, errors.New("not ecotone")).Once() + ethClient.On("CallContract", mock.Anything, mock.IsType(ethereum.CallMsg{}), mock.IsType(&big.Int{})).Run(func(args mock.Arguments) { callMsg := args.Get(1).(ethereum.CallMsg) blockNumber := args.Get(2).(*big.Int) @@ -149,7 +211,7 @@ func TestL1Oracle_GetGasCost(t *testing.T) { result = append(result, baseFee...) result = append(result, l1BaseFeeEstimate...) - ethClient := mocks.NewETHClient(t) + ethClient := mocks.NewL1OracleClient(t) ethClient.On("CallContract", mock.Anything, mock.IsType(ethereum.CallMsg{}), mock.IsType(&big.Int{})).Run(func(args mock.Arguments) { callMsg := args.Get(1).(ethereum.CallMsg) blockNumber := args.Get(2).(*big.Int) @@ -171,7 +233,7 @@ func TestL1Oracle_GetGasCost(t *testing.T) { blockNum := big.NewInt(1000) tx := types.NewTx(&types.LegacyTx{}) - ethClient := mocks.NewETHClient(t) + ethClient := mocks.NewL1OracleClient(t) oracle := NewL1GasOracle(logger.Test(t), ethClient, config.ChainKroma) _, err := oracle.GetGasCost(testutils.Context(t), tx, blockNum) @@ -195,7 +257,7 @@ func TestL1Oracle_GetGasCost(t *testing.T) { encodedTx, err := tx.MarshalBinary() require.NoError(t, err) - ethClient := mocks.NewETHClient(t) + ethClient := mocks.NewL1OracleClient(t) ethClient.On("CallContract", mock.Anything, mock.IsType(ethereum.CallMsg{}), mock.IsType(&big.Int{})).Run(func(args mock.Arguments) { callMsg := args.Get(1).(ethereum.CallMsg) blockNumber := args.Get(2).(*big.Int) @@ -230,7 +292,7 @@ func TestL1Oracle_GetGasCost(t *testing.T) { encodedTx, err := tx.MarshalBinary() require.NoError(t, err) - ethClient := mocks.NewETHClient(t) + ethClient := mocks.NewL1OracleClient(t) ethClient.On("CallContract", mock.Anything, mock.IsType(ethereum.CallMsg{}), mock.IsType(&big.Int{})).Run(func(args mock.Arguments) { callMsg := args.Get(1).(ethereum.CallMsg) blockNumber := args.Get(2).(*big.Int) diff --git a/core/chains/evm/gas/rollups/mocks/da_price_reader.go b/core/chains/evm/gas/rollups/mocks/da_price_reader.go deleted file mode 100644 index 7758f53e436..00000000000 --- a/core/chains/evm/gas/rollups/mocks/da_price_reader.go +++ /dev/null @@ -1,59 +0,0 @@ -// Code generated by mockery v2.38.0. DO NOT EDIT. - -package mocks - -import ( - context "context" - big "math/big" - - mock "github.com/stretchr/testify/mock" -) - -// DAPriceReader is an autogenerated mock type for the daPriceReader type -type DAPriceReader struct { - mock.Mock -} - -// GetDAGasPrice provides a mock function with given fields: ctx -func (_m *DAPriceReader) GetDAGasPrice(ctx context.Context) (*big.Int, error) { - ret := _m.Called(ctx) - - if len(ret) == 0 { - panic("no return value specified for GetDAGasPrice") - } - - var r0 *big.Int - var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (*big.Int, error)); ok { - return rf(ctx) - } - if rf, ok := ret.Get(0).(func(context.Context) *big.Int); ok { - r0 = rf(ctx) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*big.Int) - } - } - - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(ctx) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// NewDAPriceReader creates a new instance of DAPriceReader. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewDAPriceReader(t interface { - mock.TestingT - Cleanup(func()) -}) *DAPriceReader { - mock := &DAPriceReader{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} diff --git a/core/chains/evm/gas/rollups/mocks/eth_client.go b/core/chains/evm/gas/rollups/mocks/l1_oracle_client.go similarity index 72% rename from core/chains/evm/gas/rollups/mocks/eth_client.go rename to core/chains/evm/gas/rollups/mocks/l1_oracle_client.go index e5a28f715ad..3995a09513b 100644 --- a/core/chains/evm/gas/rollups/mocks/eth_client.go +++ b/core/chains/evm/gas/rollups/mocks/l1_oracle_client.go @@ -13,13 +13,13 @@ import ( rpc "github.com/ethereum/go-ethereum/rpc" ) -// ETHClient is an autogenerated mock type for the ethClient type -type ETHClient struct { +// L1OracleClient is an autogenerated mock type for the l1OracleClient type +type L1OracleClient struct { mock.Mock } // BatchCallContext provides a mock function with given fields: ctx, b -func (_m *ETHClient) BatchCallContext(ctx context.Context, b []rpc.BatchElem) error { +func (_m *L1OracleClient) BatchCallContext(ctx context.Context, b []rpc.BatchElem) error { ret := _m.Called(ctx, b) if len(ret) == 0 { @@ -37,7 +37,7 @@ func (_m *ETHClient) BatchCallContext(ctx context.Context, b []rpc.BatchElem) er } // CallContract provides a mock function with given fields: ctx, msg, blockNumber -func (_m *ETHClient) CallContract(ctx context.Context, msg ethereum.CallMsg, blockNumber *big.Int) ([]byte, error) { +func (_m *L1OracleClient) CallContract(ctx context.Context, msg ethereum.CallMsg, blockNumber *big.Int) ([]byte, error) { ret := _m.Called(ctx, msg, blockNumber) if len(ret) == 0 { @@ -66,13 +66,13 @@ func (_m *ETHClient) CallContract(ctx context.Context, msg ethereum.CallMsg, blo return r0, r1 } -// NewETHClient creates a new instance of ETHClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// NewL1OracleClient creates a new instance of L1OracleClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. -func NewETHClient(t interface { +func NewL1OracleClient(t interface { mock.TestingT Cleanup(func()) -}) *ETHClient { - mock := ÐClient{} +}) *L1OracleClient { + mock := &L1OracleClient{} mock.Mock.Test(t) t.Cleanup(func() { mock.AssertExpectations(t) }) diff --git a/core/chains/evm/gas/rollups/models.go b/core/chains/evm/gas/rollups/models.go deleted file mode 100644 index 7aa3d4059dd..00000000000 --- a/core/chains/evm/gas/rollups/models.go +++ /dev/null @@ -1,22 +0,0 @@ -package rollups - -import ( - "context" - "math/big" - - "github.com/ethereum/go-ethereum/core/types" - - "github.com/smartcontractkit/chainlink-common/pkg/services" - "github.com/smartcontractkit/chainlink/v2/core/chains/evm/assets" -) - -// L1Oracle provides interface for fetching L1-specific fee components if the chain is an L2. -// For example, on Optimistic Rollups, this oracle can return rollup-specific l1BaseFee -// -//go:generate mockery --quiet --name L1Oracle --output ./mocks/ --case=underscore -type L1Oracle interface { - services.Service - - GasPrice(ctx context.Context) (*assets.Wei, error) - GetGasCost(ctx context.Context, tx *types.Transaction, blockNum *big.Int) (*assets.Wei, error) -} diff --git a/core/chains/evm/gas/rollups/op_l1_oracle.go b/core/chains/evm/gas/rollups/op_l1_oracle.go new file mode 100644 index 00000000000..e180777fb61 --- /dev/null +++ b/core/chains/evm/gas/rollups/op_l1_oracle.go @@ -0,0 +1,431 @@ +package rollups + +import ( + "context" + "fmt" + "math/big" + "strings" + "sync" + "time" + + "github.com/ethereum/go-ethereum" + "github.com/ethereum/go-ethereum/accounts/abi" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/ethereum/go-ethereum/rpc" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/utils" + + gethtypes "github.com/ethereum/go-ethereum/core/types" + + "github.com/smartcontractkit/chainlink/v2/common/client" + "github.com/smartcontractkit/chainlink/v2/common/config" + "github.com/smartcontractkit/chainlink/v2/core/chains/evm/assets" + evmclient "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" +) + +// Reads L2-specific precompiles and caches the l1GasPrice set by the L2. +type OptimismL1Oracle struct { + services.StateMachine + client l1OracleClient + pollPeriod time.Duration + logger logger.SugaredLogger + chainType config.ChainType + + l1OracleAddress string + gasPriceMethod string + l1GasPriceMethodAbi abi.ABI + l1GasPriceMu sync.RWMutex + l1GasPrice priceEntry + + gasCostMethod string + l1GasCostMethodAbi abi.ABI + + chInitialised chan struct{} + chStop services.StopChan + chDone chan struct{} + + isEcotoneMethodAbi abi.ABI + + l1BaseFeeCalldata []byte + isEcotoneCalldata []byte + getL1GasUsedCalldata []byte + getL1FeeCalldata []byte + + isEcotone bool + isEcotoneCheckTs int64 +} + +const ( + // OPStackGasOracle_isEcotone fetches if the OP Stack GasPriceOracle contract has upgraded to Ecotone + OPStackGasOracle_isEcotone = "isEcotone" + // OPStackGasOracle_getL1GasUsed fetches the l1 gas used for given tx bytes + OPStackGasOracle_getL1GasUsed = "getL1GasUsed" + // OPStackGasOracle_isEcotonePollingPeriod is the interval to poll if chain has upgraded to Ecotone + // Set to poll every 4 hours + OPStackGasOracle_isEcotonePollingPeriod = 14400 + // OPStackGasOracleAddress is the address of the precompiled contract that exists on OP stack chain. + // OPStackGasOracle_l1BaseFee fetches the l1 base fee set in the OP Stack GasPriceOracle contract + // OPStackGasOracle_l1BaseFee is a hex encoded call to: + // `function l1BaseFee() external view returns (uint256);` + OPStackGasOracle_l1BaseFee = "l1BaseFee" + // OPStackGasOracle_getL1Fee fetches the l1 fee for given tx bytes + // OPStackGasOracle_getL1Fee is a hex encoded call to: + // `function getL1Fee(bytes) external view returns (uint256);` + OPStackGasOracle_getL1Fee = "getL1Fee" + // This is the case for Optimism and Base. + OPGasOracleAddress = "0x420000000000000000000000000000000000000F" + // GasOracleAddress is the address of the precompiled contract that exists on Kroma chain. + // This is the case for Kroma. + KromaGasOracleAddress = "0x4200000000000000000000000000000000000005" + // ScrollGasOracleAddress is the address of the precompiled contract that exists on Scroll chain. + ScrollGasOracleAddress = "0x5300000000000000000000000000000000000002" +) + +func NewOpStackL1GasOracle(lggr logger.Logger, ethClient l1OracleClient, chainType config.ChainType) *OptimismL1Oracle { + var precompileAddress string + switch chainType { + case config.ChainOptimismBedrock: + precompileAddress = OPGasOracleAddress + case config.ChainKroma: + precompileAddress = KromaGasOracleAddress + case config.ChainScroll: + precompileAddress = ScrollGasOracleAddress + default: + panic(fmt.Sprintf("Received unspported chaintype %s", chainType)) + } + return newOpStackL1GasOracle(lggr, ethClient, chainType, precompileAddress) +} + +func newOpStackL1GasOracle(lggr logger.Logger, ethClient l1OracleClient, chainType config.ChainType, precompileAddress string) *OptimismL1Oracle { + var l1OracleAddress, gasPriceMethod, gasCostMethod string + var l1GasPriceMethodAbi, l1GasCostMethodAbi abi.ABI + var gasPriceErr, gasCostErr error + + l1OracleAddress = precompileAddress + gasPriceMethod = OPStackGasOracle_l1BaseFee + l1GasPriceMethodAbi, gasPriceErr = abi.JSON(strings.NewReader(L1BaseFeeAbiString)) + gasCostMethod = OPStackGasOracle_getL1Fee + l1GasCostMethodAbi, gasCostErr = abi.JSON(strings.NewReader(GetL1FeeAbiString)) + + if gasPriceErr != nil { + panic(fmt.Sprintf("Failed to parse L1 gas price method ABI for chain: %s", chainType)) + } + if gasCostErr != nil { + panic(fmt.Sprintf("Failed to parse L1 gas cost method ABI for chain: %s", chainType)) + } + + // encode calldata for each method; these calldata will remain the same for each call, we can encode them just once + l1BaseFeeMethodAbi, err := abi.JSON(strings.NewReader(L1BaseFeeAbiString)) + if err != nil { + panic(fmt.Errorf("failed to parse GasPriceOracle %s() method ABI for chain: %s; %w", OPStackGasOracle_l1BaseFee, chainType, err)) + } + l1BaseFeeCalldata, err := l1BaseFeeMethodAbi.Pack(OPStackGasOracle_l1BaseFee) + if err != nil { + panic(fmt.Errorf("failed to parse GasPriceOracle %s() calldata for chain: %s; %w", OPStackGasOracle_l1BaseFee, chainType, err)) + } + + isEcotoneMethodAbi, err := abi.JSON(strings.NewReader(OPIsEcotoneAbiString)) + if err != nil { + panic(fmt.Errorf("failed to parse GasPriceOracle %s() method ABI for chain: %s; %w", OPStackGasOracle_isEcotone, chainType, err)) + } + isEcotoneCalldata, err := isEcotoneMethodAbi.Pack(OPStackGasOracle_isEcotone) + if err != nil { + panic(fmt.Errorf("failed to parse GasPriceOracle %s() calldata for chain: %s; %w", OPStackGasOracle_isEcotone, chainType, err)) + } + + getL1GasUsedMethodAbi, err := abi.JSON(strings.NewReader(OPGetL1GasUsedAbiString)) + if err != nil { + panic(fmt.Errorf("failed to parse GasPriceOracle %s() method ABI for chain: %s; %w", OPStackGasOracle_getL1GasUsed, chainType, err)) + } + getL1GasUsedCalldata, err := getL1GasUsedMethodAbi.Pack(OPStackGasOracle_getL1GasUsed, []byte{0x1}) + if err != nil { + panic(fmt.Errorf("failed to parse GasPriceOracle %s() calldata for chain: %s; %w", OPStackGasOracle_getL1GasUsed, chainType, err)) + } + + getL1FeeMethodAbi, err := abi.JSON(strings.NewReader(GetL1FeeAbiString)) + if err != nil { + panic(fmt.Errorf("failed to parse GasPriceOracle %s() method ABI for chain: %s; %w", OPStackGasOracle_getL1Fee, chainType, err)) + } + getL1FeeCalldata, err := getL1FeeMethodAbi.Pack(OPStackGasOracle_getL1Fee, []byte{0x1}) + if err != nil { + panic(fmt.Errorf("failed to parse GasPriceOracle %s() calldata for chain: %s; %w", OPStackGasOracle_getL1Fee, chainType, err)) + } + + return &OptimismL1Oracle{ + client: ethClient, + pollPeriod: PollPeriod, + logger: logger.Sugared(logger.Named(lggr, "L1GasOracle(optimismBedrock)")), + chainType: chainType, + + l1OracleAddress: l1OracleAddress, + gasPriceMethod: gasPriceMethod, + l1GasPriceMethodAbi: l1GasPriceMethodAbi, + gasCostMethod: gasCostMethod, + l1GasCostMethodAbi: l1GasCostMethodAbi, + + chInitialised: make(chan struct{}), + chStop: make(chan struct{}), + chDone: make(chan struct{}), + + isEcotoneMethodAbi: isEcotoneMethodAbi, + + l1BaseFeeCalldata: l1BaseFeeCalldata, + isEcotoneCalldata: isEcotoneCalldata, + getL1GasUsedCalldata: getL1GasUsedCalldata, + getL1FeeCalldata: getL1FeeCalldata, + + isEcotone: false, + isEcotoneCheckTs: 0, + } +} + +func (o *OptimismL1Oracle) Name() string { + return o.logger.Name() +} + +func (o *OptimismL1Oracle) Start(ctx context.Context) error { + return o.StartOnce(o.Name(), func() error { + go o.run() + <-o.chInitialised + return nil + }) +} +func (o *OptimismL1Oracle) Close() error { + return o.StopOnce(o.Name(), func() error { + close(o.chStop) + <-o.chDone + return nil + }) +} + +func (o *OptimismL1Oracle) HealthReport() map[string]error { + return map[string]error{o.Name(): o.Healthy()} +} + +func (o *OptimismL1Oracle) run() { + defer close(o.chDone) + + t := o.refresh() + close(o.chInitialised) + + for { + select { + case <-o.chStop: + return + case <-t.C: + t = o.refresh() + } + } +} +func (o *OptimismL1Oracle) refresh() (t *time.Timer) { + t, err := o.refreshWithError() + if err != nil { + o.SvcErrBuffer.Append(err) + } + return +} + +func (o *OptimismL1Oracle) refreshWithError() (t *time.Timer, err error) { + t = time.NewTimer(utils.WithJitter(o.pollPeriod)) + + ctx, cancel := o.chStop.CtxCancel(evmclient.ContextWithDefaultTimeout()) + defer cancel() + + price, err := o.GetDAGasPrice(ctx) + if err != nil { + return t, err + } + + o.l1GasPriceMu.Lock() + defer o.l1GasPriceMu.Unlock() + o.l1GasPrice = priceEntry{price: assets.NewWei(price), timestamp: time.Now()} + return +} + +func (o *OptimismL1Oracle) GasPrice(_ context.Context) (l1GasPrice *assets.Wei, err error) { + var timestamp time.Time + ok := o.IfStarted(func() { + o.l1GasPriceMu.RLock() + l1GasPrice = o.l1GasPrice.price + timestamp = o.l1GasPrice.timestamp + o.l1GasPriceMu.RUnlock() + }) + if !ok { + return l1GasPrice, fmt.Errorf("L1GasOracle is not started; cannot estimate gas") + } + if l1GasPrice == nil { + return l1GasPrice, fmt.Errorf("failed to get l1 gas price; gas price not set") + } + // Validate the price has been updated within the pollPeriod * 2 + // Allowing double the poll period before declaring the price stale to give ample time for the refresh to process + if time.Since(timestamp) > o.pollPeriod*2 { + return l1GasPrice, fmt.Errorf("gas price is stale") + } + return +} + +// Gets the L1 gas cost for the provided transaction at the specified block num +// If block num is not provided, the value on the latest block num is used +func (o *OptimismL1Oracle) GetGasCost(ctx context.Context, tx *gethtypes.Transaction, blockNum *big.Int) (*assets.Wei, error) { + ctx, cancel := context.WithTimeout(ctx, client.QueryTimeout) + defer cancel() + var callData, b []byte + var err error + if o.chainType == config.ChainKroma { + return nil, fmt.Errorf("L1 gas cost not supported for this chain: %s", o.chainType) + } + // Append rlp-encoded tx + var encodedtx []byte + if encodedtx, err = tx.MarshalBinary(); err != nil { + return nil, fmt.Errorf("failed to marshal tx for gas cost estimation: %w", err) + } + if callData, err = o.l1GasCostMethodAbi.Pack(o.gasCostMethod, encodedtx); err != nil { + return nil, fmt.Errorf("failed to pack calldata for %s L1 gas cost estimation method: %w", o.chainType, err) + } + + precompile := common.HexToAddress(o.l1OracleAddress) + b, err = o.client.CallContract(ctx, ethereum.CallMsg{ + To: &precompile, + Data: callData, + }, blockNum) + if err != nil { + errorMsg := fmt.Sprintf("gas oracle contract call failed: %v", err) + o.logger.Errorf(errorMsg) + return nil, fmt.Errorf(errorMsg) + } + + var l1GasCost *big.Int + if len(b) != 32 { // returns uint256; + errorMsg := fmt.Sprintf("return data length (%d) different than expected (%d)", len(b), 32) + o.logger.Critical(errorMsg) + return nil, fmt.Errorf(errorMsg) + } + l1GasCost = new(big.Int).SetBytes(b) + + return assets.NewWei(l1GasCost), nil +} + +func (o *OptimismL1Oracle) GetDAGasPrice(ctx context.Context) (*big.Int, error) { + isEcotone, err := o.checkIsEcotone(ctx) + if err != nil { + return nil, err + } + + o.logger.Infof("Chain isEcotone result: %t", isEcotone) + + if isEcotone { + return o.getEcotoneGasPrice(ctx) + } + + return o.getV1GasPrice(ctx) +} + +func (o *OptimismL1Oracle) checkIsEcotone(ctx context.Context) (bool, error) { + // if chain is already Ecotone, NOOP + if o.isEcotone { + return true, nil + } + // if time since last check has not exceeded polling period, NOOP + if time.Now().Unix()-o.isEcotoneCheckTs < OPStackGasOracle_isEcotonePollingPeriod { + return false, nil + } + o.isEcotoneCheckTs = time.Now().Unix() + + l1OracleAddress := common.HexToAddress(o.l1OracleAddress) + // confirmed with OP team that isEcotone() is the canonical way to check if the chain has upgraded + b, err := o.client.CallContract(ctx, ethereum.CallMsg{ + To: &l1OracleAddress, + Data: o.isEcotoneCalldata, + }, nil) + + // if the chain has not upgraded to Ecotone, the isEcotone call will revert, this would be expected + if err != nil { + o.logger.Infof("isEcotone() call failed, this can happen if chain has not upgraded: %w", err) + return false, nil + } + + res, err := o.isEcotoneMethodAbi.Unpack(OPStackGasOracle_isEcotone, b) + if err != nil { + return false, fmt.Errorf("failed to unpack isEcotone() return data: %w", err) + } + o.isEcotone = res[0].(bool) + return o.isEcotone, nil +} + +func (o *OptimismL1Oracle) getV1GasPrice(ctx context.Context) (*big.Int, error) { + l1OracleAddress := common.HexToAddress(o.l1OracleAddress) + b, err := o.client.CallContract(ctx, ethereum.CallMsg{ + To: &l1OracleAddress, + Data: o.l1BaseFeeCalldata, + }, nil) + if err != nil { + return nil, fmt.Errorf("l1BaseFee() call failed: %w", err) + } + + if len(b) != 32 { + return nil, fmt.Errorf("l1BaseFee() return data length (%d) different than expected (%d)", len(b), 32) + } + return new(big.Int).SetBytes(b), nil +} + +func (o *OptimismL1Oracle) getEcotoneGasPrice(ctx context.Context) (*big.Int, error) { + rpcBatchCalls := []rpc.BatchElem{ + { + Method: "eth_call", + Args: []any{ + map[string]interface{}{ + "from": common.Address{}, + "to": o.l1OracleAddress, + "data": hexutil.Bytes(o.getL1GasUsedCalldata), + }, + "latest", + }, + Result: new(string), + }, + { + Method: "eth_call", + Args: []any{ + map[string]interface{}{ + "from": common.Address{}, + "to": o.l1OracleAddress, + "data": hexutil.Bytes(o.getL1FeeCalldata), + }, + "latest", + }, + Result: new(string), + }, + } + + err := o.client.BatchCallContext(ctx, rpcBatchCalls) + if err != nil { + return nil, fmt.Errorf("getEcotoneGasPrice batch call failed: %w", err) + } + if rpcBatchCalls[0].Error != nil { + return nil, fmt.Errorf("%s call failed in a batch: %w", OPStackGasOracle_getL1GasUsed, err) + } + if rpcBatchCalls[1].Error != nil { + return nil, fmt.Errorf("%s call failed in a batch: %w", OPStackGasOracle_getL1Fee, err) + } + + l1GasUsedResult := *(rpcBatchCalls[0].Result.(*string)) + l1FeeResult := *(rpcBatchCalls[1].Result.(*string)) + + l1GasUsedBytes, err := hexutil.Decode(l1GasUsedResult) + if err != nil { + return nil, fmt.Errorf("failed to decode %s rpc result: %w", OPStackGasOracle_getL1GasUsed, err) + } + l1FeeBytes, err := hexutil.Decode(l1FeeResult) + if err != nil { + return nil, fmt.Errorf("failed to decode %s rpc result: %w", OPStackGasOracle_getL1Fee, err) + } + + l1GasUsed := new(big.Int).SetBytes(l1GasUsedBytes) + l1Fee := new(big.Int).SetBytes(l1FeeBytes) + + // for the same tx byte, l1Fee / l1GasUsed will give the l1 gas price + // note this price is per l1 gas, not l1 data byte + return new(big.Int).Div(l1Fee, l1GasUsed), nil +} diff --git a/core/chains/evm/gas/rollups/op_price_reader_test.go b/core/chains/evm/gas/rollups/op_l1_oracle_test.go similarity index 90% rename from core/chains/evm/gas/rollups/op_price_reader_test.go rename to core/chains/evm/gas/rollups/op_l1_oracle_test.go index dad12a16366..36e8700faff 100644 --- a/core/chains/evm/gas/rollups/op_price_reader_test.go +++ b/core/chains/evm/gas/rollups/op_l1_oracle_test.go @@ -60,7 +60,7 @@ func TestDAPriceReader_ReadV1GasPrice(t *testing.T) { isEcotoneCalldata, err := isEcotoneMethodAbi.Pack(OPStackGasOracle_isEcotone) require.NoError(t, err) - ethClient := mocks.NewETHClient(t) + ethClient := mocks.NewL1OracleClient(t) call := ethClient.On("CallContract", mock.Anything, mock.IsType(ethereum.CallMsg{}), mock.IsType(&big.Int{})).Run(func(args mock.Arguments) { callMsg := args.Get(1).(ethereum.CallMsg) blockNumber := args.Get(2).(*big.Int) @@ -87,7 +87,7 @@ func TestDAPriceReader_ReadV1GasPrice(t *testing.T) { }).Return(common.BigToHash(l1BaseFee).Bytes(), nil).Once() } - oracle := newOPPriceReader(logger.Test(t), ethClient, config.ChainOptimismBedrock, oracleAddress) + oracle := newOpStackL1GasOracle(logger.Test(t), ethClient, config.ChainOptimismBedrock, oracleAddress) gasPrice, err := oracle.GetDAGasPrice(testutils.Context(t)) if tc.returnBadData { @@ -100,13 +100,13 @@ func TestDAPriceReader_ReadV1GasPrice(t *testing.T) { } } -func setupIsEcotone(t *testing.T, oracleAddress string) *mocks.ETHClient { +func setupIsEcotone(t *testing.T, oracleAddress string) *mocks.L1OracleClient { isEcotoneMethodAbi, err := abi.JSON(strings.NewReader(OPIsEcotoneAbiString)) require.NoError(t, err) isEcotoneCalldata, err := isEcotoneMethodAbi.Pack(OPStackGasOracle_isEcotone) require.NoError(t, err) - ethClient := mocks.NewETHClient(t) + ethClient := mocks.NewL1OracleClient(t) ethClient.On("CallContract", mock.Anything, mock.IsType(ethereum.CallMsg{}), mock.IsType(&big.Int{})).Run(func(args mock.Arguments) { callMsg := args.Get(1).(ethereum.CallMsg) blockNumber := args.Get(2).(*big.Int) @@ -142,7 +142,7 @@ func TestDAPriceReader_ReadEcotoneGasPrice(t *testing.T) { for _, rE := range rpcElements { require.Equal(t, "eth_call", rE.Method) - require.Equal(t, oracleAddress, rE.Args[0].(map[string]interface{})["to"].(common.Address).String()) + require.Equal(t, oracleAddress, rE.Args[0].(map[string]interface{})["to"]) require.Equal(t, "latest", rE.Args[1]) } @@ -155,7 +155,7 @@ func TestDAPriceReader_ReadEcotoneGasPrice(t *testing.T) { rpcElements[1].Result = &res2 }).Return(nil).Once() - oracle := newOPPriceReader(logger.Test(t), ethClient, config.ChainOptimismBedrock, oracleAddress) + oracle := newOpStackL1GasOracle(logger.Test(t), ethClient, config.ChainOptimismBedrock, oracleAddress) gasPrice, err := oracle.GetDAGasPrice(testutils.Context(t)) require.NoError(t, err) assert.Equal(t, l1BaseFee, gasPrice) @@ -170,7 +170,7 @@ func TestDAPriceReader_ReadEcotoneGasPrice(t *testing.T) { rpcElements[1].Result = &badData }).Return(nil).Once() - oracle := newOPPriceReader(logger.Test(t), ethClient, config.ChainOptimismBedrock, oracleAddress) + oracle := newOpStackL1GasOracle(logger.Test(t), ethClient, config.ChainOptimismBedrock, oracleAddress) _, err := oracle.GetDAGasPrice(testutils.Context(t)) assert.Error(t, err) }) @@ -179,7 +179,7 @@ func TestDAPriceReader_ReadEcotoneGasPrice(t *testing.T) { ethClient := setupIsEcotone(t, oracleAddress) ethClient.On("BatchCallContext", mock.Anything, mock.IsType([]rpc.BatchElem{})).Return(fmt.Errorf("revert")).Once() - oracle := newOPPriceReader(logger.Test(t), ethClient, config.ChainOptimismBedrock, oracleAddress) + oracle := newOpStackL1GasOracle(logger.Test(t), ethClient, config.ChainOptimismBedrock, oracleAddress) _, err := oracle.GetDAGasPrice(testutils.Context(t)) assert.Error(t, err) }) @@ -193,7 +193,7 @@ func TestDAPriceReader_ReadEcotoneGasPrice(t *testing.T) { rpcElements[1].Error = fmt.Errorf("revert") }).Return(nil).Once() - oracle := newOPPriceReader(logger.Test(t), ethClient, config.ChainOptimismBedrock, oracleAddress) + oracle := newOpStackL1GasOracle(logger.Test(t), ethClient, config.ChainOptimismBedrock, oracleAddress) _, err := oracle.GetDAGasPrice(testutils.Context(t)) assert.Error(t, err) }) diff --git a/core/chains/evm/gas/rollups/op_price_reader.go b/core/chains/evm/gas/rollups/op_price_reader.go deleted file mode 100644 index 2d3d668ad8b..00000000000 --- a/core/chains/evm/gas/rollups/op_price_reader.go +++ /dev/null @@ -1,228 +0,0 @@ -package rollups - -import ( - "context" - "fmt" - "math/big" - "strings" - "time" - - "github.com/ethereum/go-ethereum" - "github.com/ethereum/go-ethereum/accounts/abi" - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/common/hexutil" - "github.com/ethereum/go-ethereum/rpc" - - "github.com/smartcontractkit/chainlink-common/pkg/logger" - - "github.com/smartcontractkit/chainlink/v2/common/config" -) - -const ( - // OPStackGasOracle_l1BaseFee fetches the l1 base fee set in the OP Stack GasPriceOracle contract - OPStackGasOracle_l1BaseFee = "l1BaseFee" - - // OPStackGasOracle_isEcotone fetches if the OP Stack GasPriceOracle contract has upgraded to Ecotone - OPStackGasOracle_isEcotone = "isEcotone" - - // OPStackGasOracle_getL1GasUsed fetches the l1 gas used for given tx bytes - OPStackGasOracle_getL1GasUsed = "getL1GasUsed" - - // OPStackGasOracle_getL1Fee fetches the l1 fee for given tx bytes - OPStackGasOracle_getL1Fee = "getL1Fee" - - // OPStackGasOracle_isEcotonePollingPeriod is the interval to poll if chain has upgraded to Ecotone - // Set to poll every 4 hours - OPStackGasOracle_isEcotonePollingPeriod = 14400 -) - -type opStackGasPriceReader struct { - client ethClient - logger logger.SugaredLogger - - oracleAddress common.Address - isEcotoneMethodAbi abi.ABI - - l1BaseFeeCalldata []byte - isEcotoneCalldata []byte - getL1GasUsedCalldata []byte - getL1FeeCalldata []byte - - isEcotone bool - isEcotoneCheckTs int64 -} - -func newOPPriceReader(lggr logger.Logger, ethClient ethClient, chainType config.ChainType, oracleAddress string) daPriceReader { - // encode calldata for each method; these calldata will remain the same for each call, we can encode them just once - l1BaseFeeMethodAbi, err := abi.JSON(strings.NewReader(L1BaseFeeAbiString)) - if err != nil { - panic(fmt.Errorf("failed to parse GasPriceOracle %s() method ABI for chain: %s; %w", OPStackGasOracle_l1BaseFee, chainType, err)) - } - l1BaseFeeCalldata, err := l1BaseFeeMethodAbi.Pack(OPStackGasOracle_l1BaseFee) - if err != nil { - panic(fmt.Errorf("failed to parse GasPriceOracle %s() calldata for chain: %s; %w", OPStackGasOracle_l1BaseFee, chainType, err)) - } - - isEcotoneMethodAbi, err := abi.JSON(strings.NewReader(OPIsEcotoneAbiString)) - if err != nil { - panic(fmt.Errorf("failed to parse GasPriceOracle %s() method ABI for chain: %s; %w", OPStackGasOracle_isEcotone, chainType, err)) - } - isEcotoneCalldata, err := isEcotoneMethodAbi.Pack(OPStackGasOracle_isEcotone) - if err != nil { - panic(fmt.Errorf("failed to parse GasPriceOracle %s() calldata for chain: %s; %w", OPStackGasOracle_isEcotone, chainType, err)) - } - - getL1GasUsedMethodAbi, err := abi.JSON(strings.NewReader(OPGetL1GasUsedAbiString)) - if err != nil { - panic(fmt.Errorf("failed to parse GasPriceOracle %s() method ABI for chain: %s; %w", OPStackGasOracle_getL1GasUsed, chainType, err)) - } - getL1GasUsedCalldata, err := getL1GasUsedMethodAbi.Pack(OPStackGasOracle_getL1GasUsed, []byte{0x1}) - if err != nil { - panic(fmt.Errorf("failed to parse GasPriceOracle %s() calldata for chain: %s; %w", OPStackGasOracle_getL1GasUsed, chainType, err)) - } - - getL1FeeMethodAbi, err := abi.JSON(strings.NewReader(GetL1FeeAbiString)) - if err != nil { - panic(fmt.Errorf("failed to parse GasPriceOracle %s() method ABI for chain: %s; %w", OPStackGasOracle_getL1Fee, chainType, err)) - } - getL1FeeCalldata, err := getL1FeeMethodAbi.Pack(OPStackGasOracle_getL1Fee, []byte{0x1}) - if err != nil { - panic(fmt.Errorf("failed to parse GasPriceOracle %s() calldata for chain: %s; %w", OPStackGasOracle_getL1Fee, chainType, err)) - } - - return &opStackGasPriceReader{ - client: ethClient, - logger: logger.Sugared(logger.Named(lggr, fmt.Sprintf("OPStackGasOracle(%s)", chainType))), - - oracleAddress: common.HexToAddress(oracleAddress), - isEcotoneMethodAbi: isEcotoneMethodAbi, - - l1BaseFeeCalldata: l1BaseFeeCalldata, - isEcotoneCalldata: isEcotoneCalldata, - getL1GasUsedCalldata: getL1GasUsedCalldata, - getL1FeeCalldata: getL1FeeCalldata, - - isEcotone: false, - isEcotoneCheckTs: 0, - } -} - -func (o *opStackGasPriceReader) GetDAGasPrice(ctx context.Context) (*big.Int, error) { - isEcotone, err := o.checkIsEcotone(ctx) - if err != nil { - return nil, err - } - - o.logger.Infof("Chain isEcotone result: %t", isEcotone) - - if isEcotone { - return o.getEcotoneGasPrice(ctx) - } - - return o.getV1GasPrice(ctx) -} - -func (o *opStackGasPriceReader) checkIsEcotone(ctx context.Context) (bool, error) { - // if chain is already Ecotone, NOOP - if o.isEcotone { - return true, nil - } - // if time since last check has not exceeded polling period, NOOP - if time.Now().Unix()-o.isEcotoneCheckTs < OPStackGasOracle_isEcotonePollingPeriod { - return false, nil - } - o.isEcotoneCheckTs = time.Now().Unix() - - // confirmed with OP team that isEcotone() is the canonical way to check if the chain has upgraded - b, err := o.client.CallContract(ctx, ethereum.CallMsg{ - To: &o.oracleAddress, - Data: o.isEcotoneCalldata, - }, nil) - - // if the chain has not upgraded to Ecotone, the isEcotone call will revert, this would be expected - if err != nil { - o.logger.Infof("isEcotone() call failed, this can happen if chain has not upgraded: %w", err) - return false, nil - } - - res, err := o.isEcotoneMethodAbi.Unpack(OPStackGasOracle_isEcotone, b) - if err != nil { - return false, fmt.Errorf("failed to unpack isEcotone() return data: %w", err) - } - o.isEcotone = res[0].(bool) - return o.isEcotone, nil -} - -func (o *opStackGasPriceReader) getV1GasPrice(ctx context.Context) (*big.Int, error) { - b, err := o.client.CallContract(ctx, ethereum.CallMsg{ - To: &o.oracleAddress, - Data: o.l1BaseFeeCalldata, - }, nil) - if err != nil { - return nil, fmt.Errorf("l1BaseFee() call failed: %w", err) - } - - if len(b) != 32 { - return nil, fmt.Errorf("l1BaseFee() return data length (%d) different than expected (%d)", len(b), 32) - } - return new(big.Int).SetBytes(b), nil -} - -func (o *opStackGasPriceReader) getEcotoneGasPrice(ctx context.Context) (*big.Int, error) { - rpcBatchCalls := []rpc.BatchElem{ - { - Method: "eth_call", - Args: []any{ - map[string]interface{}{ - "from": common.Address{}, - "to": o.oracleAddress, - "data": hexutil.Bytes(o.getL1GasUsedCalldata), - }, - "latest", - }, - Result: new(string), - }, - { - Method: "eth_call", - Args: []any{ - map[string]interface{}{ - "from": common.Address{}, - "to": o.oracleAddress, - "data": hexutil.Bytes(o.getL1FeeCalldata), - }, - "latest", - }, - Result: new(string), - }, - } - - err := o.client.BatchCallContext(ctx, rpcBatchCalls) - if err != nil { - return nil, fmt.Errorf("getEcotoneGasPrice batch call failed: %w", err) - } - if rpcBatchCalls[0].Error != nil { - return nil, fmt.Errorf("%s call failed in a batch: %w", OPStackGasOracle_getL1GasUsed, err) - } - if rpcBatchCalls[1].Error != nil { - return nil, fmt.Errorf("%s call failed in a batch: %w", OPStackGasOracle_getL1Fee, err) - } - - l1GasUsedResult := *(rpcBatchCalls[0].Result.(*string)) - l1FeeResult := *(rpcBatchCalls[1].Result.(*string)) - - l1GasUsedBytes, err := hexutil.Decode(l1GasUsedResult) - if err != nil { - return nil, fmt.Errorf("failed to decode %s rpc result: %w", OPStackGasOracle_getL1GasUsed, err) - } - l1FeeBytes, err := hexutil.Decode(l1FeeResult) - if err != nil { - return nil, fmt.Errorf("failed to decode %s rpc result: %w", OPStackGasOracle_getL1Fee, err) - } - - l1GasUsed := new(big.Int).SetBytes(l1GasUsedBytes) - l1Fee := new(big.Int).SetBytes(l1FeeBytes) - - // for the same tx byte, l1Fee / l1GasUsed will give the l1 gas price - // note this price is per l1 gas, not l1 data byte - return new(big.Int).Div(l1Fee, l1GasUsed), nil -} diff --git a/core/chains/evm/gas/suggested_price_estimator.go b/core/chains/evm/gas/suggested_price_estimator.go index edc1b0f92fa..e947e9109d1 100644 --- a/core/chains/evm/gas/suggested_price_estimator.go +++ b/core/chains/evm/gas/suggested_price_estimator.go @@ -19,6 +19,7 @@ import ( feetypes "github.com/smartcontractkit/chainlink/v2/common/fee/types" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/assets" evmclient "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" + "github.com/smartcontractkit/chainlink/v2/core/chains/evm/gas/rollups" evmtypes "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" ) @@ -31,8 +32,7 @@ type suggestedPriceConfig interface { BumpMin() *assets.Wei } -//go:generate mockery --quiet --name rpcClient --output ./mocks/ --case=underscore --structname RPCClient -type rpcClient interface { +type suggestedPriceEstimatorClient interface { CallContext(ctx context.Context, result interface{}, method string, args ...interface{}) error } @@ -41,7 +41,7 @@ type SuggestedPriceEstimator struct { services.StateMachine cfg suggestedPriceConfig - client rpcClient + client suggestedPriceEstimatorClient pollPeriod time.Duration logger logger.Logger @@ -52,10 +52,12 @@ type SuggestedPriceEstimator struct { chInitialised chan struct{} chStop services.StopChan chDone chan struct{} + + l1Oracle rollups.L1Oracle } // NewSuggestedPriceEstimator returns a new Estimator which uses the suggested gas price. -func NewSuggestedPriceEstimator(lggr logger.Logger, client rpcClient, cfg suggestedPriceConfig) EvmEstimator { +func NewSuggestedPriceEstimator(lggr logger.Logger, client feeEstimatorClient, cfg suggestedPriceConfig, l1Oracle rollups.L1Oracle) EvmEstimator { return &SuggestedPriceEstimator{ client: client, pollPeriod: 10 * time.Second, @@ -65,6 +67,7 @@ func NewSuggestedPriceEstimator(lggr logger.Logger, client rpcClient, cfg sugges chInitialised: make(chan struct{}), chStop: make(chan struct{}), chDone: make(chan struct{}), + l1Oracle: l1Oracle, } } @@ -72,6 +75,10 @@ func (o *SuggestedPriceEstimator) Name() string { return o.logger.Name() } +func (o *SuggestedPriceEstimator) L1Oracle() rollups.L1Oracle { + return o.l1Oracle +} + func (o *SuggestedPriceEstimator) Start(context.Context) error { return o.StartOnce("SuggestedPriceEstimator", func() error { go o.run() diff --git a/core/chains/evm/gas/suggested_price_estimator_test.go b/core/chains/evm/gas/suggested_price_estimator_test.go index 0d52d6ab1b9..4f3c4d307d6 100644 --- a/core/chains/evm/gas/suggested_price_estimator_test.go +++ b/core/chains/evm/gas/suggested_price_estimator_test.go @@ -15,6 +15,7 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/chains/evm/assets" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/gas" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/gas/mocks" + rollupMocks "github.com/smartcontractkit/chainlink/v2/core/chains/evm/gas/rollups/mocks" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" ) @@ -29,20 +30,24 @@ func TestSuggestedPriceEstimator(t *testing.T) { cfg := &gas.MockGasEstimatorConfig{BumpPercentF: 10, BumpMinF: assets.NewWei(big.NewInt(1)), BumpThresholdF: 1} t.Run("calling GetLegacyGas on unstarted estimator returns error", func(t *testing.T) { - client := mocks.NewRPCClient(t) - o := gas.NewSuggestedPriceEstimator(logger.Test(t), client, cfg) + feeEstimatorClient := mocks.NewFeeEstimatorClient(t) + l1Oracle := rollupMocks.NewL1Oracle(t) + + o := gas.NewSuggestedPriceEstimator(logger.Test(t), feeEstimatorClient, cfg, l1Oracle) _, _, err := o.GetLegacyGas(testutils.Context(t), calldata, gasLimit, maxGasPrice) assert.EqualError(t, err, "estimator is not started") }) t.Run("calling GetLegacyGas on started estimator returns prices", func(t *testing.T) { - client := mocks.NewRPCClient(t) - client.On("CallContext", mock.Anything, mock.Anything, "eth_gasPrice").Return(nil).Run(func(args mock.Arguments) { + feeEstimatorClient := mocks.NewFeeEstimatorClient(t) + l1Oracle := rollupMocks.NewL1Oracle(t) + + feeEstimatorClient.On("CallContext", mock.Anything, mock.Anything, "eth_gasPrice").Return(nil).Run(func(args mock.Arguments) { res := args.Get(1).(*hexutil.Big) (*big.Int)(res).SetInt64(42) }) - o := gas.NewSuggestedPriceEstimator(logger.Test(t), client, cfg) + o := gas.NewSuggestedPriceEstimator(logger.Test(t), feeEstimatorClient, cfg, l1Oracle) servicetest.RunHealthy(t, o) gasPrice, chainSpecificGasLimit, err := o.GetLegacyGas(testutils.Context(t), calldata, gasLimit, maxGasPrice) require.NoError(t, err) @@ -51,10 +56,12 @@ func TestSuggestedPriceEstimator(t *testing.T) { }) t.Run("gas price is lower than user specified max gas price", func(t *testing.T) { - client := mocks.NewRPCClient(t) - o := gas.NewSuggestedPriceEstimator(logger.Test(t), client, cfg) + feeEstimatorClient := mocks.NewFeeEstimatorClient(t) + l1Oracle := rollupMocks.NewL1Oracle(t) - client.On("CallContext", mock.Anything, mock.Anything, "eth_gasPrice").Return(nil).Run(func(args mock.Arguments) { + o := gas.NewSuggestedPriceEstimator(logger.Test(t), feeEstimatorClient, cfg, l1Oracle) + + feeEstimatorClient.On("CallContext", mock.Anything, mock.Anything, "eth_gasPrice").Return(nil).Run(func(args mock.Arguments) { res := args.Get(1).(*hexutil.Big) (*big.Int)(res).SetInt64(42) }) @@ -68,10 +75,12 @@ func TestSuggestedPriceEstimator(t *testing.T) { }) t.Run("gas price is lower than global max gas price", func(t *testing.T) { - client := mocks.NewRPCClient(t) - o := gas.NewSuggestedPriceEstimator(logger.Test(t), client, cfg) + feeEstimatorClient := mocks.NewFeeEstimatorClient(t) + l1Oracle := rollupMocks.NewL1Oracle(t) + + o := gas.NewSuggestedPriceEstimator(logger.Test(t), feeEstimatorClient, cfg, l1Oracle) - client.On("CallContext", mock.Anything, mock.Anything, "eth_gasPrice").Return(nil).Run(func(args mock.Arguments) { + feeEstimatorClient.On("CallContext", mock.Anything, mock.Anything, "eth_gasPrice").Return(nil).Run(func(args mock.Arguments) { res := args.Get(1).(*hexutil.Big) (*big.Int)(res).SetInt64(120) }) @@ -84,10 +93,12 @@ func TestSuggestedPriceEstimator(t *testing.T) { }) t.Run("calling GetLegacyGas on started estimator if initial call failed returns error", func(t *testing.T) { - client := mocks.NewRPCClient(t) - o := gas.NewSuggestedPriceEstimator(logger.Test(t), client, cfg) + feeEstimatorClient := mocks.NewFeeEstimatorClient(t) + l1Oracle := rollupMocks.NewL1Oracle(t) - client.On("CallContext", mock.Anything, mock.Anything, "eth_gasPrice").Return(pkgerrors.New("kaboom")) + o := gas.NewSuggestedPriceEstimator(logger.Test(t), feeEstimatorClient, cfg, l1Oracle) + + feeEstimatorClient.On("CallContext", mock.Anything, mock.Anything, "eth_gasPrice").Return(pkgerrors.New("kaboom")) servicetest.RunHealthy(t, o) @@ -96,22 +107,28 @@ func TestSuggestedPriceEstimator(t *testing.T) { }) t.Run("calling GetDynamicFee always returns error", func(t *testing.T) { - client := mocks.NewRPCClient(t) - o := gas.NewSuggestedPriceEstimator(logger.Test(t), client, cfg) + feeEstimatorClient := mocks.NewFeeEstimatorClient(t) + l1Oracle := rollupMocks.NewL1Oracle(t) + + o := gas.NewSuggestedPriceEstimator(logger.Test(t), feeEstimatorClient, cfg, l1Oracle) _, err := o.GetDynamicFee(testutils.Context(t), maxGasPrice) assert.EqualError(t, err, "dynamic fees are not implemented for this estimator") }) t.Run("calling BumpLegacyGas on unstarted estimator returns error", func(t *testing.T) { - client := mocks.NewRPCClient(t) - o := gas.NewSuggestedPriceEstimator(logger.Test(t), client, cfg) + feeEstimatorClient := mocks.NewFeeEstimatorClient(t) + l1Oracle := rollupMocks.NewL1Oracle(t) + + o := gas.NewSuggestedPriceEstimator(logger.Test(t), feeEstimatorClient, cfg, l1Oracle) _, _, err := o.BumpLegacyGas(testutils.Context(t), assets.NewWeiI(42), gasLimit, maxGasPrice, nil) assert.EqualError(t, err, "estimator is not started") }) t.Run("calling BumpDynamicFee always returns error", func(t *testing.T) { - client := mocks.NewRPCClient(t) - o := gas.NewSuggestedPriceEstimator(logger.Test(t), client, cfg) + feeEstimatorClient := mocks.NewFeeEstimatorClient(t) + l1Oracle := rollupMocks.NewL1Oracle(t) + + o := gas.NewSuggestedPriceEstimator(logger.Test(t), feeEstimatorClient, cfg, l1Oracle) fee := gas.DynamicFee{ FeeCap: assets.NewWeiI(42), TipCap: assets.NewWeiI(5), @@ -121,13 +138,15 @@ func TestSuggestedPriceEstimator(t *testing.T) { }) t.Run("calling BumpLegacyGas on started estimator returns new price buffered with bumpPercent", func(t *testing.T) { - client := mocks.NewRPCClient(t) - client.On("CallContext", mock.Anything, mock.Anything, "eth_gasPrice").Return(nil).Run(func(args mock.Arguments) { + feeEstimatorClient := mocks.NewFeeEstimatorClient(t) + l1Oracle := rollupMocks.NewL1Oracle(t) + + feeEstimatorClient.On("CallContext", mock.Anything, mock.Anything, "eth_gasPrice").Return(nil).Run(func(args mock.Arguments) { res := args.Get(1).(*hexutil.Big) (*big.Int)(res).SetInt64(40) }) - o := gas.NewSuggestedPriceEstimator(logger.Test(t), client, cfg) + o := gas.NewSuggestedPriceEstimator(logger.Test(t), feeEstimatorClient, cfg, l1Oracle) servicetest.RunHealthy(t, o) gasPrice, chainSpecificGasLimit, err := o.BumpLegacyGas(testutils.Context(t), assets.NewWeiI(10), gasLimit, maxGasPrice, nil) require.NoError(t, err) @@ -136,14 +155,16 @@ func TestSuggestedPriceEstimator(t *testing.T) { }) t.Run("calling BumpLegacyGas on started estimator returns new price buffered with bumpMin", func(t *testing.T) { - client := mocks.NewRPCClient(t) - client.On("CallContext", mock.Anything, mock.Anything, "eth_gasPrice").Return(nil).Run(func(args mock.Arguments) { + feeEstimatorClient := mocks.NewFeeEstimatorClient(t) + l1Oracle := rollupMocks.NewL1Oracle(t) + + feeEstimatorClient.On("CallContext", mock.Anything, mock.Anything, "eth_gasPrice").Return(nil).Run(func(args mock.Arguments) { res := args.Get(1).(*hexutil.Big) (*big.Int)(res).SetInt64(40) }) - testCfg := &gas.MockGasEstimatorConfig{BumpPercentF: 1, BumpMinF: assets.NewWei(big.NewInt(1)), BumpThresholdF: 1} - o := gas.NewSuggestedPriceEstimator(logger.Test(t), client, testCfg) + testCfg := &gas.MockGasEstimatorConfig{BumpPercentF: 1, BumpMinF: assets.NewWei(big.NewInt(1)), BumpThresholdF: 1, LimitMultiplierF: 1} + o := gas.NewSuggestedPriceEstimator(logger.Test(t), feeEstimatorClient, testCfg, l1Oracle) servicetest.RunHealthy(t, o) gasPrice, chainSpecificGasLimit, err := o.BumpLegacyGas(testutils.Context(t), assets.NewWeiI(10), gasLimit, maxGasPrice, nil) require.NoError(t, err) @@ -152,13 +173,15 @@ func TestSuggestedPriceEstimator(t *testing.T) { }) t.Run("calling BumpLegacyGas on started estimator returns original price when lower than previous", func(t *testing.T) { - client := mocks.NewRPCClient(t) - client.On("CallContext", mock.Anything, mock.Anything, "eth_gasPrice").Return(nil).Run(func(args mock.Arguments) { + feeEstimatorClient := mocks.NewFeeEstimatorClient(t) + l1Oracle := rollupMocks.NewL1Oracle(t) + + feeEstimatorClient.On("CallContext", mock.Anything, mock.Anything, "eth_gasPrice").Return(nil).Run(func(args mock.Arguments) { res := args.Get(1).(*hexutil.Big) (*big.Int)(res).SetInt64(5) }) - o := gas.NewSuggestedPriceEstimator(logger.Test(t), client, cfg) + o := gas.NewSuggestedPriceEstimator(logger.Test(t), feeEstimatorClient, cfg, l1Oracle) servicetest.RunHealthy(t, o) gasPrice, chainSpecificGasLimit, err := o.BumpLegacyGas(testutils.Context(t), assets.NewWeiI(10), gasLimit, maxGasPrice, nil) require.NoError(t, err) @@ -167,10 +190,12 @@ func TestSuggestedPriceEstimator(t *testing.T) { }) t.Run("calling BumpLegacyGas on started estimator returns error, suggested gas price is higher than max gas price", func(t *testing.T) { - client := mocks.NewRPCClient(t) - o := gas.NewSuggestedPriceEstimator(logger.Test(t), client, cfg) + feeEstimatorClient := mocks.NewFeeEstimatorClient(t) + l1Oracle := rollupMocks.NewL1Oracle(t) + + o := gas.NewSuggestedPriceEstimator(logger.Test(t), feeEstimatorClient, cfg, l1Oracle) - client.On("CallContext", mock.Anything, mock.Anything, "eth_gasPrice").Return(nil).Run(func(args mock.Arguments) { + feeEstimatorClient.On("CallContext", mock.Anything, mock.Anything, "eth_gasPrice").Return(nil).Run(func(args mock.Arguments) { res := args.Get(1).(*hexutil.Big) (*big.Int)(res).SetInt64(42) }) @@ -184,10 +209,12 @@ func TestSuggestedPriceEstimator(t *testing.T) { }) t.Run("calling BumpLegacyGas on started estimator returns max gas price when suggested price under max but the buffer exceeds it", func(t *testing.T) { - client := mocks.NewRPCClient(t) - o := gas.NewSuggestedPriceEstimator(logger.Test(t), client, cfg) + feeEstimatorClient := mocks.NewFeeEstimatorClient(t) + l1Oracle := rollupMocks.NewL1Oracle(t) - client.On("CallContext", mock.Anything, mock.Anything, "eth_gasPrice").Return(nil).Run(func(args mock.Arguments) { + o := gas.NewSuggestedPriceEstimator(logger.Test(t), feeEstimatorClient, cfg, l1Oracle) + + feeEstimatorClient.On("CallContext", mock.Anything, mock.Anything, "eth_gasPrice").Return(nil).Run(func(args mock.Arguments) { res := args.Get(1).(*hexutil.Big) (*big.Int)(res).SetInt64(39) }) @@ -200,10 +227,12 @@ func TestSuggestedPriceEstimator(t *testing.T) { }) t.Run("calling BumpLegacyGas on started estimator if initial call failed returns error", func(t *testing.T) { - client := mocks.NewRPCClient(t) - o := gas.NewSuggestedPriceEstimator(logger.Test(t), client, cfg) + feeEstimatorClient := mocks.NewFeeEstimatorClient(t) + l1Oracle := rollupMocks.NewL1Oracle(t) + + o := gas.NewSuggestedPriceEstimator(logger.Test(t), feeEstimatorClient, cfg, l1Oracle) - client.On("CallContext", mock.Anything, mock.Anything, "eth_gasPrice").Return(pkgerrors.New("kaboom")) + feeEstimatorClient.On("CallContext", mock.Anything, mock.Anything, "eth_gasPrice").Return(pkgerrors.New("kaboom")) servicetest.RunHealthy(t, o) @@ -212,14 +241,16 @@ func TestSuggestedPriceEstimator(t *testing.T) { }) t.Run("calling BumpLegacyGas on started estimator if refresh call failed returns price from previous update", func(t *testing.T) { - client := mocks.NewRPCClient(t) - o := gas.NewSuggestedPriceEstimator(logger.Test(t), client, cfg) + feeEstimatorClient := mocks.NewFeeEstimatorClient(t) + l1Oracle := rollupMocks.NewL1Oracle(t) + + o := gas.NewSuggestedPriceEstimator(logger.Test(t), feeEstimatorClient, cfg, l1Oracle) - client.On("CallContext", mock.Anything, mock.Anything, "eth_gasPrice").Return(nil).Run(func(args mock.Arguments) { + feeEstimatorClient.On("CallContext", mock.Anything, mock.Anything, "eth_gasPrice").Return(nil).Run(func(args mock.Arguments) { res := args.Get(1).(*hexutil.Big) (*big.Int)(res).SetInt64(40) }).Once() - client.On("CallContext", mock.Anything, mock.Anything, "eth_gasPrice").Return(pkgerrors.New("kaboom")) + feeEstimatorClient.On("CallContext", mock.Anything, mock.Anything, "eth_gasPrice").Return(pkgerrors.New("kaboom")) servicetest.RunHealthy(t, o) diff --git a/core/chains/evm/txmgr/broadcaster_test.go b/core/chains/evm/txmgr/broadcaster_test.go index 3500002e8da..070c2c37473 100644 --- a/core/chains/evm/txmgr/broadcaster_test.go +++ b/core/chains/evm/txmgr/broadcaster_test.go @@ -64,9 +64,9 @@ func NewTestEthBroadcaster( lggr := logger.Test(t) ge := config.EVM().GasEstimator() - estimator := gas.NewWrappedEvmEstimator(lggr, func(lggr logger.Logger) gas.EvmEstimator { - return gas.NewFixedPriceEstimator(config.EVM().GasEstimator(), ge.BlockHistory(), lggr) - }, ge.EIP1559DynamicFees(), nil, ge) + estimator := gas.NewEvmFeeEstimator(lggr, func(lggr logger.Logger) gas.EvmEstimator { + return gas.NewFixedPriceEstimator(config.EVM().GasEstimator(), nil, ge.BlockHistory(), lggr, nil) + }, ge.EIP1559DynamicFees(), ge) txBuilder := txmgr.NewEvmTxAttemptBuilder(*ethClient.ConfiguredChainID(), ge, keyStore, estimator) ethBroadcaster := txmgrcommon.NewBroadcaster(txStore, txmgr.NewEvmTxmClient(ethClient), txmgr.NewEvmTxmConfig(config.EVM()), txmgr.NewEvmTxmFeeConfig(config.EVM().GasEstimator()), config.EVM().Transactions(), config.Database().Listener(), keyStore, txBuilder, nonceTracker, lggr, checkerFactory, nonceAutoSync) @@ -1152,9 +1152,9 @@ func TestEthBroadcaster_ProcessUnstartedEthTxs_Errors(t *testing.T) { // same as the parent test, but callback is set by ctor t.Run("callback set by ctor", func(t *testing.T) { - estimator := gas.NewWrappedEvmEstimator(lggr, func(lggr logger.Logger) gas.EvmEstimator { - return gas.NewFixedPriceEstimator(evmcfg.EVM().GasEstimator(), evmcfg.EVM().GasEstimator().BlockHistory(), lggr) - }, evmcfg.EVM().GasEstimator().EIP1559DynamicFees(), nil, evmcfg.EVM().GasEstimator()) + estimator := gas.NewEvmFeeEstimator(lggr, func(lggr logger.Logger) gas.EvmEstimator { + return gas.NewFixedPriceEstimator(evmcfg.EVM().GasEstimator(), nil, evmcfg.EVM().GasEstimator().BlockHistory(), lggr, nil) + }, evmcfg.EVM().GasEstimator().EIP1559DynamicFees(), evmcfg.EVM().GasEstimator()) txBuilder := txmgr.NewEvmTxAttemptBuilder(*ethClient.ConfiguredChainID(), evmcfg.EVM().GasEstimator(), ethKeyStore, estimator) localNextNonce = getLocalNextNonce(t, nonceTracker, fromAddress) eb2 := txmgr.NewEvmBroadcaster(txStore, txmClient, txmgr.NewEvmTxmConfig(evmcfg.EVM()), txmgr.NewEvmTxmFeeConfig(evmcfg.EVM().GasEstimator()), evmcfg.EVM().Transactions(), evmcfg.Database().Listener(), ethKeyStore, txBuilder, lggr, &testCheckerFactory{}, false) @@ -1738,9 +1738,9 @@ func TestEthBroadcaster_SyncNonce(t *testing.T) { kst := cltest.NewKeyStore(t, db, cfg.Database()).Eth() _, fromAddress := cltest.RandomKey{Disabled: false}.MustInsertWithState(t, kst) - estimator := gas.NewWrappedEvmEstimator(lggr, func(lggr logger.Logger) gas.EvmEstimator { - return gas.NewFixedPriceEstimator(evmcfg.EVM().GasEstimator(), evmcfg.EVM().GasEstimator().BlockHistory(), lggr) - }, evmcfg.EVM().GasEstimator().EIP1559DynamicFees(), nil, evmcfg.EVM().GasEstimator()) + estimator := gas.NewEvmFeeEstimator(lggr, func(lggr logger.Logger) gas.EvmEstimator { + return gas.NewFixedPriceEstimator(evmcfg.EVM().GasEstimator(), nil, evmcfg.EVM().GasEstimator().BlockHistory(), lggr, nil) + }, evmcfg.EVM().GasEstimator().EIP1559DynamicFees(), evmcfg.EVM().GasEstimator()) checkerFactory := &testCheckerFactory{} ge := evmcfg.EVM().GasEstimator() diff --git a/core/chains/evm/txmgr/confirmer_test.go b/core/chains/evm/txmgr/confirmer_test.go index 357dafcbdc4..89e88d5a6dc 100644 --- a/core/chains/evm/txmgr/confirmer_test.go +++ b/core/chains/evm/txmgr/confirmer_test.go @@ -127,7 +127,7 @@ func TestEthConfirmer_Lifecycle(t *testing.T) { newEst := func(logger.Logger) gas.EvmEstimator { return estimator } lggr := logger.Test(t) ge := config.EVM().GasEstimator() - feeEstimator := gas.NewWrappedEvmEstimator(lggr, newEst, ge.EIP1559DynamicFees(), nil, ge) + feeEstimator := gas.NewEvmFeeEstimator(lggr, newEst, ge.EIP1559DynamicFees(), ge) txBuilder := txmgr.NewEvmTxAttemptBuilder(*ethClient.ConfiguredChainID(), ge, ethKeyStore, feeEstimator) ec := txmgr.NewEvmConfirmer(txStore, txmgr.NewEvmTxmClient(ethClient), txmgr.NewEvmTxmConfig(config.EVM()), txmgr.NewEvmTxmFeeConfig(ge), config.EVM().Transactions(), config.Database(), ethKeyStore, txBuilder, lggr) ctx := testutils.Context(t) @@ -1647,7 +1647,7 @@ func TestEthConfirmer_RebroadcastWhereNecessary_WithConnectivityCheck(t *testing newEst := func(logger.Logger) gas.EvmEstimator { return estimator } estimator.On("BumpLegacyGas", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, uint64(0), pkgerrors.Wrapf(commonfee.ErrConnectivity, "transaction...")) ge := ccfg.EVM().GasEstimator() - feeEstimator := gas.NewWrappedEvmEstimator(lggr, newEst, ge.EIP1559DynamicFees(), nil, ge) + feeEstimator := gas.NewEvmFeeEstimator(lggr, newEst, ge.EIP1559DynamicFees(), ge) txBuilder := txmgr.NewEvmTxAttemptBuilder(*ethClient.ConfiguredChainID(), ge, kst, feeEstimator) addresses := []gethCommon.Address{fromAddress} kst.On("EnabledAddressesForChain", mock.Anything, &cltest.FixtureChainID).Return(addresses, nil).Maybe() @@ -1695,7 +1695,7 @@ func TestEthConfirmer_RebroadcastWhereNecessary_WithConnectivityCheck(t *testing newEst := func(logger.Logger) gas.EvmEstimator { return estimator } // Create confirmer with necessary state ge := ccfg.EVM().GasEstimator() - feeEstimator := gas.NewWrappedEvmEstimator(lggr, newEst, ge.EIP1559DynamicFees(), nil, ge) + feeEstimator := gas.NewEvmFeeEstimator(lggr, newEst, ge.EIP1559DynamicFees(), ge) txBuilder := txmgr.NewEvmTxAttemptBuilder(*ethClient.ConfiguredChainID(), ge, kst, feeEstimator) addresses := []gethCommon.Address{fromAddress} kst.On("EnabledAddressesForChain", mock.Anything, &cltest.FixtureChainID).Return(addresses, nil).Maybe() @@ -3133,9 +3133,9 @@ func ptr[T any](t T) *T { return &t } func newEthConfirmer(t testing.TB, txStore txmgr.EvmTxStore, ethClient client.Client, config evmconfig.ChainScopedConfig, ks keystore.Eth, fn txmgrcommon.ResumeCallback) *txmgr.Confirmer { lggr := logger.Test(t) ge := config.EVM().GasEstimator() - estimator := gas.NewWrappedEvmEstimator(lggr, func(lggr logger.Logger) gas.EvmEstimator { - return gas.NewFixedPriceEstimator(ge, ge.BlockHistory(), lggr) - }, ge.EIP1559DynamicFees(), nil, ge) + estimator := gas.NewEvmFeeEstimator(lggr, func(lggr logger.Logger) gas.EvmEstimator { + return gas.NewFixedPriceEstimator(ge, nil, ge.BlockHistory(), lggr, nil) + }, ge.EIP1559DynamicFees(), ge) txBuilder := txmgr.NewEvmTxAttemptBuilder(*ethClient.ConfiguredChainID(), ge, ks, estimator) ec := txmgr.NewEvmConfirmer(txStore, txmgr.NewEvmTxmClient(ethClient), txmgr.NewEvmTxmConfig(config.EVM()), txmgr.NewEvmTxmFeeConfig(ge), config.EVM().Transactions(), config.Database(), ks, txBuilder, lggr) ec.SetResumeCallback(fn) From 8337fc821baf8011c6c73203482db85f5a44d7ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Deividas=20Kar=C5=BEinauskas?= Date: Thu, 18 Apr 2024 20:24:43 +0300 Subject: [PATCH 19/19] Keystone refactoring (#12880) * Try a simple change * Regenerate keystone wrappers * Add changesets * Remove Keystone gas tests * Remove parenthesis --- .changeset/great-rockets-obey.md | 5 +++++ .github/workflows/solidity-foundry.yml | 2 +- contracts/.changeset/mean-items-talk.md | 5 +++++ contracts/gas-snapshots/keystone.gas-snapshot | 2 -- contracts/src/v0.8/keystone/KeystoneForwarder.sol | 12 +++++++++++- .../keystone/generated/forwarder/forwarder.go | 4 ++-- ...rated-wrapper-dependency-versions-do-not-edit.txt | 2 +- 7 files changed, 25 insertions(+), 7 deletions(-) create mode 100644 .changeset/great-rockets-obey.md create mode 100644 contracts/.changeset/mean-items-talk.md delete mode 100644 contracts/gas-snapshots/keystone.gas-snapshot diff --git a/.changeset/great-rockets-obey.md b/.changeset/great-rockets-obey.md new file mode 100644 index 00000000000..b90bc810a01 --- /dev/null +++ b/.changeset/great-rockets-obey.md @@ -0,0 +1,5 @@ +--- +"chainlink": patch +--- + +#wip Keystone wrapper regenerate diff --git a/.github/workflows/solidity-foundry.yml b/.github/workflows/solidity-foundry.yml index cea16f45f16..2d62ef864a5 100644 --- a/.github/workflows/solidity-foundry.yml +++ b/.github/workflows/solidity-foundry.yml @@ -79,7 +79,7 @@ jobs: FOUNDRY_PROFILE: ${{ matrix.product }} - name: Run Forge snapshot - if: ${{ !contains(fromJson('["vrf"]'), matrix.product) && !contains(fromJson('["automation"]'), matrix.product) && needs.changes.outputs.changes == 'true' }} + if: ${{ !contains(fromJson('["vrf"]'), matrix.product) && !contains(fromJson('["automation"]'), matrix.product) && !contains(fromJson('["keystone"]'), matrix.product) && needs.changes.outputs.changes == 'true' }} run: | forge snapshot --nmt "testFuzz_\w{1,}?" --check gas-snapshots/${{ matrix.product }}.gas-snapshot id: snapshot diff --git a/contracts/.changeset/mean-items-talk.md b/contracts/.changeset/mean-items-talk.md new file mode 100644 index 00000000000..e03d49335ad --- /dev/null +++ b/contracts/.changeset/mean-items-talk.md @@ -0,0 +1,5 @@ +--- +"@chainlink/contracts": patch +--- + +#wip Keystone custom error diff --git a/contracts/gas-snapshots/keystone.gas-snapshot b/contracts/gas-snapshots/keystone.gas-snapshot deleted file mode 100644 index 6797bd77e20..00000000000 --- a/contracts/gas-snapshots/keystone.gas-snapshot +++ /dev/null @@ -1,2 +0,0 @@ -KeystoneForwarderTest:test_abi_partial_decoding_works() (gas: 5123) -KeystoneForwarderTest:test_it_works() (gas: 996215) \ No newline at end of file diff --git a/contracts/src/v0.8/keystone/KeystoneForwarder.sol b/contracts/src/v0.8/keystone/KeystoneForwarder.sol index b4a9501e8f4..e6e2675fa2d 100644 --- a/contracts/src/v0.8/keystone/KeystoneForwarder.sol +++ b/contracts/src/v0.8/keystone/KeystoneForwarder.sol @@ -10,6 +10,14 @@ import {Utils} from "./libraries/Utils.sol"; contract KeystoneForwarder is IForwarder, ConfirmedOwner, TypeAndVersionInterface { error ReentrantCall(); + /// @notice This error is returned when the data with report is invalid. + /// This can happen if the data is shorter than SELECTOR_LENGTH + REPORT_LENGTH. + /// @param data the data that was received + error InvalidData(bytes data); + + uint256 private constant SELECTOR_LENGTH = 4; + uint256 private constant REPORT_LENGTH = 64; + struct HotVars { bool reentrancyGuard; // guard against reentrancy } @@ -26,7 +34,9 @@ contract KeystoneForwarder is IForwarder, ConfirmedOwner, TypeAndVersionInterfac bytes calldata data, bytes[] calldata signatures ) external nonReentrant returns (bool) { - require(data.length > 4 + 64, "invalid data length"); + if (data.length < SELECTOR_LENGTH + REPORT_LENGTH) { + revert InvalidData(data); + } // data is an encoded call with the selector prefixed: (bytes4 selector, bytes report, ...) // we are able to partially decode just the first param, since we don't know the rest diff --git a/core/gethwrappers/keystone/generated/forwarder/forwarder.go b/core/gethwrappers/keystone/generated/forwarder/forwarder.go index c66e2886793..c8cf31ae869 100644 --- a/core/gethwrappers/keystone/generated/forwarder/forwarder.go +++ b/core/gethwrappers/keystone/generated/forwarder/forwarder.go @@ -31,8 +31,8 @@ var ( ) var KeystoneForwarderMetaData = &bind.MetaData{ - ABI: "[{\"inputs\":[],\"stateMutability\":\"nonpayable\",\"type\":\"constructor\"},{\"inputs\":[],\"name\":\"ReentrantCall\",\"type\":\"error\"},{\"anonymous\":false,\"inputs\":[{\"indexed\":true,\"internalType\":\"address\",\"name\":\"from\",\"type\":\"address\"},{\"indexed\":true,\"internalType\":\"address\",\"name\":\"to\",\"type\":\"address\"}],\"name\":\"OwnershipTransferRequested\",\"type\":\"event\"},{\"anonymous\":false,\"inputs\":[{\"indexed\":true,\"internalType\":\"address\",\"name\":\"from\",\"type\":\"address\"},{\"indexed\":true,\"internalType\":\"address\",\"name\":\"to\",\"type\":\"address\"}],\"name\":\"OwnershipTransferred\",\"type\":\"event\"},{\"inputs\":[],\"name\":\"acceptOwnership\",\"outputs\":[],\"stateMutability\":\"nonpayable\",\"type\":\"function\"},{\"inputs\":[{\"internalType\":\"bytes32\",\"name\":\"workflowExecutionId\",\"type\":\"bytes32\"}],\"name\":\"getTransmitter\",\"outputs\":[{\"internalType\":\"address\",\"name\":\"\",\"type\":\"address\"}],\"stateMutability\":\"view\",\"type\":\"function\"},{\"inputs\":[],\"name\":\"owner\",\"outputs\":[{\"internalType\":\"address\",\"name\":\"\",\"type\":\"address\"}],\"stateMutability\":\"view\",\"type\":\"function\"},{\"inputs\":[{\"internalType\":\"address\",\"name\":\"targetAddress\",\"type\":\"address\"},{\"internalType\":\"bytes\",\"name\":\"data\",\"type\":\"bytes\"},{\"internalType\":\"bytes[]\",\"name\":\"signatures\",\"type\":\"bytes[]\"}],\"name\":\"report\",\"outputs\":[{\"internalType\":\"bool\",\"name\":\"\",\"type\":\"bool\"}],\"stateMutability\":\"nonpayable\",\"type\":\"function\"},{\"inputs\":[{\"internalType\":\"address\",\"name\":\"to\",\"type\":\"address\"}],\"name\":\"transferOwnership\",\"outputs\":[],\"stateMutability\":\"nonpayable\",\"type\":\"function\"},{\"inputs\":[],\"name\":\"typeAndVersion\",\"outputs\":[{\"internalType\":\"string\",\"name\":\"\",\"type\":\"string\"}],\"stateMutability\":\"pure\",\"type\":\"function\"}]", - Bin: "0x608060405234801561001057600080fd5b5033806000816100675760405162461bcd60e51b815260206004820152601860248201527f43616e6e6f7420736574206f776e657220746f207a65726f000000000000000060448201526064015b60405180910390fd5b600080546001600160a01b0319166001600160a01b0384811691909117909155811615610097576100978161009f565b505050610148565b336001600160a01b038216036100f75760405162461bcd60e51b815260206004820152601760248201527f43616e6e6f74207472616e7366657220746f2073656c66000000000000000000604482015260640161005e565b600180546001600160a01b0319166001600160a01b0383811691821790925560008054604051929316917fed8889f560326eb138920d842192f0eb3dd22b4f139c87a2c57538e05bae12789190a350565b610c12806101576000396000f3fe608060405234801561001057600080fd5b50600436106100725760003560e01c8063c0965dc311610050578063c0965dc314610108578063e6b714581461012b578063f2fde38b1461016157600080fd5b8063181f5a771461007757806379ba5097146100bf5780638da5cb5b146100c9575b600080fd5b604080518082018252601781527f4b657973746f6e65466f7277617264657220312e302e30000000000000000000602082015290516100b69190610827565b60405180910390f35b6100c7610174565b005b60005473ffffffffffffffffffffffffffffffffffffffff165b60405173ffffffffffffffffffffffffffffffffffffffff90911681526020016100b6565b61011b6101163660046108bc565b610276565b60405190151581526020016100b6565b6100e3610139366004610998565b60009081526003602052604090205473ffffffffffffffffffffffffffffffffffffffff1690565b6100c761016f3660046109b1565b61058e565b60015473ffffffffffffffffffffffffffffffffffffffff1633146101fa576040517f08c379a000000000000000000000000000000000000000000000000000000000815260206004820152601660248201527f4d7573742062652070726f706f736564206f776e65720000000000000000000060448201526064015b60405180910390fd5b60008054337fffffffffffffffffffffffff00000000000000000000000000000000000000008083168217845560018054909116905560405173ffffffffffffffffffffffffffffffffffffffff90921692909183917f8be0079c531659141344cd1fd0a4f28419497f9722a3daafe3b4186f6b6457e091a350565b60025460009060ff16156102b6576040517f37ed32e800000000000000000000000000000000000000000000000000000000815260040160405180910390fd5b600280547fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff001660011790556044841161034b576040517f08c379a000000000000000000000000000000000000000000000000000000000815260206004820152601360248201527f696e76616c69642064617461206c656e6774680000000000000000000000000060448201526064016101f1565b600061035a85600481896109d3565b8101906103679190610a2c565b8051602082012090915060005b848110156104655760008060006103e289898681811061039657610396610afb565b90506020028101906103a89190610b2a565b8080601f0160208091040260200160405190810160405280939291908181526020018383808284376000920191909152506105a292505050565b925092509250600060018683868660405160008152602001604052604051610426949392919093845260ff9290921660208401526040830152606082015260800190565b6020604051602081039080840390855afa158015610448573d6000803e3d6000fd5b5086955061045d9450859350610b9692505050565b915050610374565b5060008061047284610630565b600081815260036020526040902054919350915073ffffffffffffffffffffffffffffffffffffffff16156104ae57600094505050505061055d565b6000808b73ffffffffffffffffffffffffffffffffffffffff168b8b6040516104d8929190610bf5565b6000604051808303816000865af19150503d8060008114610515576040519150601f19603f3d011682016040523d82523d6000602084013e61051a565b606091505b5050506000928352505060036020526040902080547fffffffffffffffffffffffff00000000000000000000000000000000000000001633179055506001925050505b600280547fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff0016905595945050505050565b6105966106af565b61059f81610732565b50565b60008060008351604114610612576040517f08c379a000000000000000000000000000000000000000000000000000000000815260206004820152601860248201527f696e76616c6964207369676e6174757265206c656e677468000000000000000060448201526064016101f1565b50505060208101516040820151606090920151909260009190911a90565b600080604083511161069e576040517f08c379a000000000000000000000000000000000000000000000000000000000815260206004820152601560248201527f696e76616c6964207265706f7274206c656e677468000000000000000000000060448201526064016101f1565b505060208101516040909101519091565b60005473ffffffffffffffffffffffffffffffffffffffff163314610730576040517f08c379a000000000000000000000000000000000000000000000000000000000815260206004820152601660248201527f4f6e6c792063616c6c61626c65206279206f776e65720000000000000000000060448201526064016101f1565b565b3373ffffffffffffffffffffffffffffffffffffffff8216036107b1576040517f08c379a000000000000000000000000000000000000000000000000000000000815260206004820152601760248201527f43616e6e6f74207472616e7366657220746f2073656c6600000000000000000060448201526064016101f1565b600180547fffffffffffffffffffffffff00000000000000000000000000000000000000001673ffffffffffffffffffffffffffffffffffffffff83811691821790925560008054604051929316917fed8889f560326eb138920d842192f0eb3dd22b4f139c87a2c57538e05bae12789190a350565b600060208083528351808285015260005b8181101561085457858101830151858201604001528201610838565b5060006040828601015260407fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe0601f8301168501019250505092915050565b803573ffffffffffffffffffffffffffffffffffffffff811681146108b757600080fd5b919050565b6000806000806000606086880312156108d457600080fd5b6108dd86610893565b9450602086013567ffffffffffffffff808211156108fa57600080fd5b818801915088601f83011261090e57600080fd5b81358181111561091d57600080fd5b89602082850101111561092f57600080fd5b60208301965080955050604088013591508082111561094d57600080fd5b818801915088601f83011261096157600080fd5b81358181111561097057600080fd5b8960208260051b850101111561098557600080fd5b9699959850939650602001949392505050565b6000602082840312156109aa57600080fd5b5035919050565b6000602082840312156109c357600080fd5b6109cc82610893565b9392505050565b600080858511156109e357600080fd5b838611156109f057600080fd5b5050820193919092039150565b7f4e487b7100000000000000000000000000000000000000000000000000000000600052604160045260246000fd5b600060208284031215610a3e57600080fd5b813567ffffffffffffffff80821115610a5657600080fd5b818401915084601f830112610a6a57600080fd5b813581811115610a7c57610a7c6109fd565b604051601f82017fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe0908116603f01168101908382118183101715610ac257610ac26109fd565b81604052828152876020848701011115610adb57600080fd5b826020860160208301376000928101602001929092525095945050505050565b7f4e487b7100000000000000000000000000000000000000000000000000000000600052603260045260246000fd5b60008083357fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe1843603018112610b5f57600080fd5b83018035915067ffffffffffffffff821115610b7a57600080fd5b602001915036819003821315610b8f57600080fd5b9250929050565b60007fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff8203610bee577f4e487b7100000000000000000000000000000000000000000000000000000000600052601160045260246000fd5b5060010190565b818382376000910190815291905056fea164736f6c6343000813000a", + ABI: "[{\"inputs\":[],\"stateMutability\":\"nonpayable\",\"type\":\"constructor\"},{\"inputs\":[{\"internalType\":\"bytes\",\"name\":\"data\",\"type\":\"bytes\"}],\"name\":\"InvalidData\",\"type\":\"error\"},{\"inputs\":[],\"name\":\"ReentrantCall\",\"type\":\"error\"},{\"anonymous\":false,\"inputs\":[{\"indexed\":true,\"internalType\":\"address\",\"name\":\"from\",\"type\":\"address\"},{\"indexed\":true,\"internalType\":\"address\",\"name\":\"to\",\"type\":\"address\"}],\"name\":\"OwnershipTransferRequested\",\"type\":\"event\"},{\"anonymous\":false,\"inputs\":[{\"indexed\":true,\"internalType\":\"address\",\"name\":\"from\",\"type\":\"address\"},{\"indexed\":true,\"internalType\":\"address\",\"name\":\"to\",\"type\":\"address\"}],\"name\":\"OwnershipTransferred\",\"type\":\"event\"},{\"inputs\":[],\"name\":\"acceptOwnership\",\"outputs\":[],\"stateMutability\":\"nonpayable\",\"type\":\"function\"},{\"inputs\":[{\"internalType\":\"bytes32\",\"name\":\"workflowExecutionId\",\"type\":\"bytes32\"}],\"name\":\"getTransmitter\",\"outputs\":[{\"internalType\":\"address\",\"name\":\"\",\"type\":\"address\"}],\"stateMutability\":\"view\",\"type\":\"function\"},{\"inputs\":[],\"name\":\"owner\",\"outputs\":[{\"internalType\":\"address\",\"name\":\"\",\"type\":\"address\"}],\"stateMutability\":\"view\",\"type\":\"function\"},{\"inputs\":[{\"internalType\":\"address\",\"name\":\"targetAddress\",\"type\":\"address\"},{\"internalType\":\"bytes\",\"name\":\"data\",\"type\":\"bytes\"},{\"internalType\":\"bytes[]\",\"name\":\"signatures\",\"type\":\"bytes[]\"}],\"name\":\"report\",\"outputs\":[{\"internalType\":\"bool\",\"name\":\"\",\"type\":\"bool\"}],\"stateMutability\":\"nonpayable\",\"type\":\"function\"},{\"inputs\":[{\"internalType\":\"address\",\"name\":\"to\",\"type\":\"address\"}],\"name\":\"transferOwnership\",\"outputs\":[],\"stateMutability\":\"nonpayable\",\"type\":\"function\"},{\"inputs\":[],\"name\":\"typeAndVersion\",\"outputs\":[{\"internalType\":\"string\",\"name\":\"\",\"type\":\"string\"}],\"stateMutability\":\"pure\",\"type\":\"function\"}]", + Bin: "0x608060405234801561001057600080fd5b5033806000816100675760405162461bcd60e51b815260206004820152601860248201527f43616e6e6f7420736574206f776e657220746f207a65726f000000000000000060448201526064015b60405180910390fd5b600080546001600160a01b0319166001600160a01b0384811691909117909155811615610097576100978161009f565b505050610148565b336001600160a01b038216036100f75760405162461bcd60e51b815260206004820152601760248201527f43616e6e6f74207472616e7366657220746f2073656c66000000000000000000604482015260640161005e565b600180546001600160a01b0319166001600160a01b0383811691821790925560008054604051929316917fed8889f560326eb138920d842192f0eb3dd22b4f139c87a2c57538e05bae12789190a350565b610c5f806101576000396000f3fe608060405234801561001057600080fd5b50600436106100725760003560e01c8063c0965dc311610050578063c0965dc314610108578063e6b714581461012b578063f2fde38b1461016157600080fd5b8063181f5a771461007757806379ba5097146100bf5780638da5cb5b146100c9575b600080fd5b604080518082018252601781527f4b657973746f6e65466f7277617264657220312e302e30000000000000000000602082015290516100b69190610806565b60405180910390f35b6100c7610174565b005b60005473ffffffffffffffffffffffffffffffffffffffff165b60405173ffffffffffffffffffffffffffffffffffffffff90911681526020016100b6565b61011b61011636600461089b565b610276565b60405190151581526020016100b6565b6100e3610139366004610977565b60009081526003602052604090205473ffffffffffffffffffffffffffffffffffffffff1690565b6100c761016f366004610990565b61056d565b60015473ffffffffffffffffffffffffffffffffffffffff1633146101fa576040517f08c379a000000000000000000000000000000000000000000000000000000000815260206004820152601660248201527f4d7573742062652070726f706f736564206f776e65720000000000000000000060448201526064015b60405180910390fd5b60008054337fffffffffffffffffffffffff00000000000000000000000000000000000000008083168217845560018054909116905560405173ffffffffffffffffffffffffffffffffffffffff90921692909183917f8be0079c531659141344cd1fd0a4f28419497f9722a3daafe3b4186f6b6457e091a350565b60025460009060ff16156102b6576040517f37ed32e800000000000000000000000000000000000000000000000000000000815260040160405180910390fd5b600280547fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff001660011790556102ed604060046109e1565b84101561032a5784846040517f2a62609b0000000000000000000000000000000000000000000000000000000081526004016101f19291906109fa565b60006103398560048189610a47565b8101906103469190610aa0565b8051602082012090915060005b848110156104445760008060006103c189898681811061037557610375610b6f565b90506020028101906103879190610b9e565b8080601f01602080910402602001604051908101604052809392919081815260200183838082843760009201919091525061058192505050565b925092509250600060018683868660405160008152602001604052604051610405949392919093845260ff9290921660208401526040830152606082015260800190565b6020604051602081039080840390855afa158015610427573d6000803e3d6000fd5b5086955061043c9450859350610c0a92505050565b915050610353565b506000806104518461060f565b600081815260036020526040902054919350915073ffffffffffffffffffffffffffffffffffffffff161561048d57600094505050505061053c565b6000808b73ffffffffffffffffffffffffffffffffffffffff168b8b6040516104b7929190610c42565b6000604051808303816000865af19150503d80600081146104f4576040519150601f19603f3d011682016040523d82523d6000602084013e6104f9565b606091505b5050506000928352505060036020526040902080547fffffffffffffffffffffffff00000000000000000000000000000000000000001633179055506001925050505b600280547fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff0016905595945050505050565b61057561068e565b61057e81610711565b50565b600080600083516041146105f1576040517f08c379a000000000000000000000000000000000000000000000000000000000815260206004820152601860248201527f696e76616c6964207369676e6174757265206c656e677468000000000000000060448201526064016101f1565b50505060208101516040820151606090920151909260009190911a90565b600080604083511161067d576040517f08c379a000000000000000000000000000000000000000000000000000000000815260206004820152601560248201527f696e76616c6964207265706f7274206c656e677468000000000000000000000060448201526064016101f1565b505060208101516040909101519091565b60005473ffffffffffffffffffffffffffffffffffffffff16331461070f576040517f08c379a000000000000000000000000000000000000000000000000000000000815260206004820152601660248201527f4f6e6c792063616c6c61626c65206279206f776e65720000000000000000000060448201526064016101f1565b565b3373ffffffffffffffffffffffffffffffffffffffff821603610790576040517f08c379a000000000000000000000000000000000000000000000000000000000815260206004820152601760248201527f43616e6e6f74207472616e7366657220746f2073656c6600000000000000000060448201526064016101f1565b600180547fffffffffffffffffffffffff00000000000000000000000000000000000000001673ffffffffffffffffffffffffffffffffffffffff83811691821790925560008054604051929316917fed8889f560326eb138920d842192f0eb3dd22b4f139c87a2c57538e05bae12789190a350565b600060208083528351808285015260005b8181101561083357858101830151858201604001528201610817565b5060006040828601015260407fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe0601f8301168501019250505092915050565b803573ffffffffffffffffffffffffffffffffffffffff8116811461089657600080fd5b919050565b6000806000806000606086880312156108b357600080fd5b6108bc86610872565b9450602086013567ffffffffffffffff808211156108d957600080fd5b818801915088601f8301126108ed57600080fd5b8135818111156108fc57600080fd5b89602082850101111561090e57600080fd5b60208301965080955050604088013591508082111561092c57600080fd5b818801915088601f83011261094057600080fd5b81358181111561094f57600080fd5b8960208260051b850101111561096457600080fd5b9699959850939650602001949392505050565b60006020828403121561098957600080fd5b5035919050565b6000602082840312156109a257600080fd5b6109ab82610872565b9392505050565b7f4e487b7100000000000000000000000000000000000000000000000000000000600052601160045260246000fd5b808201808211156109f4576109f46109b2565b92915050565b60208152816020820152818360408301376000818301604090810191909152601f9092017fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe0160101919050565b60008085851115610a5757600080fd5b83861115610a6457600080fd5b5050820193919092039150565b7f4e487b7100000000000000000000000000000000000000000000000000000000600052604160045260246000fd5b600060208284031215610ab257600080fd5b813567ffffffffffffffff80821115610aca57600080fd5b818401915084601f830112610ade57600080fd5b813581811115610af057610af0610a71565b604051601f82017fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe0908116603f01168101908382118183101715610b3657610b36610a71565b81604052828152876020848701011115610b4f57600080fd5b826020860160208301376000928101602001929092525095945050505050565b7f4e487b7100000000000000000000000000000000000000000000000000000000600052603260045260246000fd5b60008083357fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe1843603018112610bd357600080fd5b83018035915067ffffffffffffffff821115610bee57600080fd5b602001915036819003821315610c0357600080fd5b9250929050565b60007fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff8203610c3b57610c3b6109b2565b5060010190565b818382376000910190815291905056fea164736f6c6343000813000a", } var KeystoneForwarderABI = KeystoneForwarderMetaData.ABI diff --git a/core/gethwrappers/keystone/generation/generated-wrapper-dependency-versions-do-not-edit.txt b/core/gethwrappers/keystone/generation/generated-wrapper-dependency-versions-do-not-edit.txt index b9d8bfbfefc..8b1c830405d 100644 --- a/core/gethwrappers/keystone/generation/generated-wrapper-dependency-versions-do-not-edit.txt +++ b/core/gethwrappers/keystone/generation/generated-wrapper-dependency-versions-do-not-edit.txt @@ -1,3 +1,3 @@ GETH_VERSION: 1.13.8 -forwarder: ../../../contracts/solc/v0.8.19/KeystoneForwarder/KeystoneForwarder.abi ../../../contracts/solc/v0.8.19/KeystoneForwarder/KeystoneForwarder.bin 4886b538e1fdc8aaf860901de36269e0c35acfd3e6eb190654d693ff9dbd4b6d +forwarder: ../../../contracts/solc/v0.8.19/KeystoneForwarder/KeystoneForwarder.abi ../../../contracts/solc/v0.8.19/KeystoneForwarder/KeystoneForwarder.bin b4c900aae9e022f01abbac7993d41f93912247613ac6270b0c4da4ef6f2016e3 ocr3_capability: ../../../contracts/solc/v0.8.19/OCR3Capability/OCR3Capability.abi ../../../contracts/solc/v0.8.19/OCR3Capability/OCR3Capability.bin 9dcbdf55bd5729ba266148da3f17733eb592c871c2108ccca546618628fd9ad2