diff --git a/core/capabilities/ccip/launcher/integration_test.go b/core/capabilities/ccip/launcher/integration_test.go index f0a4bd46bb3..954fda03969 100644 --- a/core/capabilities/ccip/launcher/integration_test.go +++ b/core/capabilities/ccip/launcher/integration_test.go @@ -1,6 +1,7 @@ package launcher import ( + "context" "testing" "time" @@ -115,7 +116,7 @@ type oracleCreatorPrints struct { t *testing.T } -func (o *oracleCreatorPrints) Create(_ uint32, config cctypes.OCR3ConfigWithMeta) (cctypes.CCIPOracle, error) { +func (o *oracleCreatorPrints) Create(ctx context.Context, _ uint32, config cctypes.OCR3ConfigWithMeta) (cctypes.CCIPOracle, error) { pluginType := cctypes.PluginType(config.Config.PluginType) o.t.Logf("Creating plugin oracle (pluginType: %s) with config %+v\n", pluginType, config) return &oraclePrints{pluginType: pluginType, config: config, t: o.t}, nil diff --git a/core/capabilities/ccip/launcher/launcher.go b/core/capabilities/ccip/launcher/launcher.go index 167ac0a815e..76a6c204058 100644 --- a/core/capabilities/ccip/launcher/launcher.go +++ b/core/capabilities/ccip/launcher/launcher.go @@ -68,7 +68,7 @@ type launcher struct { myP2PID ragep2ptypes.PeerID lggr logger.Logger homeChainReader ccipreader.HomeChain - stopChan chan struct{} + stopChan services.StopChan // latestState is the latest capability registry state received from the syncer. latestState registrysyncer.LocalRegistry // regState is the latest capability registry state that we have successfully processed. @@ -140,12 +140,16 @@ func (l *launcher) Start(context.Context) error { func (l *launcher) monitor() { defer l.wg.Done() ticker := time.NewTicker(l.tickInterval) + + ctx, cancel := l.stopChan.NewCtx() + defer cancel() + for { select { - case <-l.stopChan: + case <-ctx.Done(): return case <-ticker.C: - if err := l.tick(); err != nil { + if err := l.tick(ctx); err != nil { l.lggr.Errorw("Failed to tick", "err", err) } } @@ -154,7 +158,7 @@ func (l *launcher) monitor() { // tick gets the latest registry state and processes the diff between the current and latest state. // This may lead to starting or stopping OCR instances. -func (l *launcher) tick() error { +func (l *launcher) tick(ctx context.Context) error { // Ensure that the home chain reader is healthy. // For new jobs it may be possible that the home chain reader is not yet ready // so we won't be able to fetch configs and start any OCR instances. @@ -171,7 +175,7 @@ func (l *launcher) tick() error { return fmt.Errorf("failed to diff capability registry states: %w", err) } - err = l.processDiff(diffRes) + err = l.processDiff(ctx, diffRes) if err != nil { return fmt.Errorf("failed to process diff: %w", err) } @@ -183,17 +187,17 @@ func (l *launcher) tick() error { // for any added OCR instances, it will launch them. // for any removed OCR instances, it will shut them down. // for any updated OCR instances, it will restart them with the new configuration. -func (l *launcher) processDiff(diff diffResult) error { +func (l *launcher) processDiff(ctx context.Context, diff diffResult) error { err := l.processRemoved(diff.removed) - err = multierr.Append(err, l.processAdded(diff.added)) - err = multierr.Append(err, l.processUpdate(diff.updated)) + err = multierr.Append(err, l.processAdded(ctx, diff.added)) + err = multierr.Append(err, l.processUpdate(ctx, diff.updated)) return err } // processUpdate will manage when configurations of an existing don are updated // If new oracles are needed, they are created and started. Old ones will be shut down -func (l *launcher) processUpdate(updated map[registrysyncer.DonID]registrysyncer.DON) error { +func (l *launcher) processUpdate(ctx context.Context, updated map[registrysyncer.DonID]registrysyncer.DON) error { l.lock.Lock() defer l.lock.Unlock() @@ -203,12 +207,13 @@ func (l *launcher) processUpdate(updated map[registrysyncer.DonID]registrysyncer return fmt.Errorf("invariant violation: expected to find CCIP DON %d in the map of running deployments", don.ID) } - latestConfigs, err := getConfigsForDon(l.homeChainReader, don) + latestConfigs, err := getConfigsForDon(ctx, l.homeChainReader, don) if err != nil { return err } newPlugins, err := updateDON( + ctx, l.lggr, l.myP2PID, prevPlugins, @@ -233,16 +238,17 @@ func (l *launcher) processUpdate(updated map[registrysyncer.DonID]registrysyncer // processAdded is for when a new don is created. We know that all oracles // must be created and started -func (l *launcher) processAdded(added map[registrysyncer.DonID]registrysyncer.DON) error { +func (l *launcher) processAdded(ctx context.Context, added map[registrysyncer.DonID]registrysyncer.DON) error { l.lock.Lock() defer l.lock.Unlock() for donID, don := range added { - configs, err := getConfigsForDon(l.homeChainReader, don) + configs, err := getConfigsForDon(ctx, l.homeChainReader, don) if err != nil { return fmt.Errorf("failed to get current configs for don %d: %w", donID, err) } newPlugins, err := createDON( + ctx, l.lggr, l.myP2PID, don, @@ -300,6 +306,7 @@ func (l *launcher) processRemoved(removed map[registrysyncer.DonID]registrysynce } func updateDON( + ctx context.Context, lggr logger.Logger, p2pID ragep2ptypes.PeerID, prevPlugins pluginRegistry, @@ -318,7 +325,7 @@ func updateDON( for _, c := range latestConfigs { digest := c.ConfigDigest if _, ok := prevPlugins[digest]; !ok { - oracle, err := oracleCreator.Create(don.ID, cctypes.OCR3ConfigWithMeta(c)) + oracle, err := oracleCreator.Create(ctx, don.ID, cctypes.OCR3ConfigWithMeta(c)) if err != nil { return nil, fmt.Errorf("failed to create CCIP oracle: %w for digest %x", err, digest) } @@ -335,6 +342,7 @@ func updateDON( // createDON is a pure function that handles the case where a new DON is added to the capability registry. // It returns up to 4 plugins that are later started. func createDON( + ctx context.Context, lggr logger.Logger, p2pID ragep2ptypes.PeerID, don registrysyncer.DON, @@ -352,7 +360,7 @@ func createDON( return nil, fmt.Errorf("digest does not match type %w", err) } - oracle, err := oracleCreator.Create(don.ID, cctypes.OCR3ConfigWithMeta(config)) + oracle, err := oracleCreator.Create(ctx, don.ID, cctypes.OCR3ConfigWithMeta(config)) if err != nil { return nil, fmt.Errorf("failed to create CCIP oracle: %w for digest %x", err, digest) } @@ -363,16 +371,17 @@ func createDON( } func getConfigsForDon( + ctx context.Context, homeChainReader ccipreader.HomeChain, don registrysyncer.DON) ([]ccipreader.OCR3ConfigWithMeta, error) { // this should be a retryable error. - commitOCRConfigs, err := homeChainReader.GetOCRConfigs(context.Background(), don.ID, uint8(cctypes.PluginTypeCCIPCommit)) + commitOCRConfigs, err := homeChainReader.GetOCRConfigs(ctx, don.ID, uint8(cctypes.PluginTypeCCIPCommit)) if err != nil { return nil, fmt.Errorf("failed to fetch OCR configs for CCIP commit plugin (don id: %d) from home chain config contract: %w", don.ID, err) } - execOCRConfigs, err := homeChainReader.GetOCRConfigs(context.Background(), don.ID, uint8(cctypes.PluginTypeCCIPExec)) + execOCRConfigs, err := homeChainReader.GetOCRConfigs(ctx, don.ID, uint8(cctypes.PluginTypeCCIPExec)) if err != nil { return nil, fmt.Errorf("failed to fetch OCR configs for CCIP exec plugin (don id: %d) from home chain config contract: %w", don.ID, err) diff --git a/core/capabilities/ccip/launcher/launcher_test.go b/core/capabilities/ccip/launcher/launcher_test.go index 188ee48c215..3e3bd1a4368 100644 --- a/core/capabilities/ccip/launcher/launcher_test.go +++ b/core/capabilities/ccip/launcher/launcher_test.go @@ -8,6 +8,7 @@ import ( cctypes "github.com/smartcontractkit/chainlink/v2/core/capabilities/ccip/types" "github.com/smartcontractkit/chainlink/v2/core/capabilities/ccip/types/mocks" + "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" ragep2ptypes "github.com/smartcontractkit/libocr/ragep2p/types" "github.com/stretchr/testify/mock" @@ -113,7 +114,7 @@ func Test_createDON(t *testing.T) { }, }, nil) oracleCreator.EXPECT().Type().Return(cctypes.OracleTypeBootstrap).Once() - oracleCreator.EXPECT().Create(mock.Anything, mock.Anything).Return(mocks.NewCCIPOracle(t), nil).Twice() + oracleCreator.EXPECT().Create(mock.Anything, mock.Anything, mock.Anything).Return(mocks.NewCCIPOracle(t), nil).Twice() }, false, }, @@ -153,11 +154,11 @@ func Test_createDON(t *testing.T) { }, }, nil) - oracleCreator.EXPECT().Create(mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool { + oracleCreator.EXPECT().Create(mock.Anything, mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool { return cfg.Config.PluginType == uint8(cctypes.PluginTypeCCIPCommit) })). Return(mocks.NewCCIPOracle(t), nil) - oracleCreator.EXPECT().Create(mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool { + oracleCreator.EXPECT().Create(mock.Anything, mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool { return cfg.Config.PluginType == uint8(cctypes.PluginTypeCCIPExec) })). Return(mocks.NewCCIPOracle(t), nil) @@ -212,11 +213,11 @@ func Test_createDON(t *testing.T) { }, }, nil) - oracleCreator.EXPECT().Create(mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool { + oracleCreator.EXPECT().Create(mock.Anything, mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool { return cfg.Config.PluginType == uint8(cctypes.PluginTypeCCIPCommit) })). Return(mocks.NewCCIPOracle(t), nil).Twice() - oracleCreator.EXPECT().Create(mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool { + oracleCreator.EXPECT().Create(mock.Anything, mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool { return cfg.Config.PluginType == uint8(cctypes.PluginTypeCCIPExec) })). Return(mocks.NewCCIPOracle(t), nil).Twice() @@ -229,10 +230,11 @@ func Test_createDON(t *testing.T) { if tt.expect != nil { tt.expect(t, tt.args, tt.args.oracleCreator, tt.args.homeChainReader) } + ctx := testutils.Context(t) - latestConfigs, err := getConfigsForDon(tt.args.homeChainReader, tt.args.don) + latestConfigs, err := getConfigsForDon(ctx, tt.args.homeChainReader, tt.args.don) require.NoError(t, err) - _, err = createDON(tt.args.lggr, tt.args.p2pID, tt.args.don, tt.args.oracleCreator, latestConfigs) + _, err = createDON(ctx, tt.args.lggr, tt.args.p2pID, tt.args.don, tt.args.oracleCreator, latestConfigs) if tt.wantErr { require.Error(t, err) } else { @@ -304,11 +306,11 @@ func Test_updateDON(t *testing.T) { ConfigDigest: utils.RandomBytes32(), }, }, nil) - oracleCreator.EXPECT().Create(mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool { + oracleCreator.EXPECT().Create(mock.Anything, mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool { return cfg.Config.PluginType == uint8(cctypes.PluginTypeCCIPCommit) })). Return(mocks.NewCCIPOracle(t), nil) - oracleCreator.EXPECT().Create(mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool { + oracleCreator.EXPECT().Create(mock.Anything, mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool { return cfg.Config.PluginType == uint8(cctypes.PluginTypeCCIPExec) })). Return(mocks.NewCCIPOracle(t), nil) @@ -405,11 +407,11 @@ func Test_updateDON(t *testing.T) { ConfigDigest: utils.RandomBytes32(), }, }, nil) - oracleCreator.EXPECT().Create(mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool { + oracleCreator.EXPECT().Create(mock.Anything, mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool { return cfg.Config.PluginType == uint8(cctypes.PluginTypeCCIPCommit) })). Return(mocks.NewCCIPOracle(t), nil).Once() - oracleCreator.EXPECT().Create(mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool { + oracleCreator.EXPECT().Create(mock.Anything, mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool { return cfg.Config.PluginType == uint8(cctypes.PluginTypeCCIPExec) })). Return(mocks.NewCCIPOracle(t), nil).Once() @@ -472,11 +474,11 @@ func Test_updateDON(t *testing.T) { ConfigDigest: utils.RandomBytes32(), }, }, nil) - oracleCreator.EXPECT().Create(mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool { + oracleCreator.EXPECT().Create(mock.Anything, mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool { return cfg.Config.PluginType == uint8(cctypes.PluginTypeCCIPCommit) })). Return(mocks.NewCCIPOracle(t), nil).Twice() - oracleCreator.EXPECT().Create(mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool { + oracleCreator.EXPECT().Create(mock.Anything, mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool { return cfg.Config.PluginType == uint8(cctypes.PluginTypeCCIPExec) })). Return(mocks.NewCCIPOracle(t), nil).Twice() @@ -489,10 +491,11 @@ func Test_updateDON(t *testing.T) { if tt.expect != nil { tt.expect(t, tt.args, tt.args.oracleCreator, tt.args.homeChainReader) } + ctx := testutils.Context(t) - latestConfigs, err := getConfigsForDon(tt.args.homeChainReader, tt.args.don) + latestConfigs, err := getConfigsForDon(ctx, tt.args.homeChainReader, tt.args.don) require.NoError(t, err) - newPlugins, err := updateDON(tt.args.lggr, tt.args.p2pID, tt.args.prevPlugins, tt.args.don, tt.args.oracleCreator, latestConfigs) + newPlugins, err := updateDON(ctx, tt.args.lggr, tt.args.p2pID, tt.args.prevPlugins, tt.args.don, tt.args.oracleCreator, latestConfigs) if (err != nil) != tt.wantErr { t.Errorf("updateDON() error = %v, wantErr %v", err, tt.wantErr) return @@ -602,11 +605,11 @@ func Test_launcher_processDiff(t *testing.T) { commitOracle.On("Start").Return(nil) execOracle := mocks.NewCCIPOracle(t) execOracle.On("Start").Return(nil) - m.EXPECT().Create(mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool { + m.EXPECT().Create(mock.Anything, mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool { return cfg.Config.PluginType == uint8(cctypes.PluginTypeCCIPCommit) })). Return(commitOracle, nil) - m.EXPECT().Create(mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool { + m.EXPECT().Create(mock.Anything, mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool { return cfg.Config.PluginType == uint8(cctypes.PluginTypeCCIPExec) })). Return(execOracle, nil) @@ -679,11 +682,11 @@ func Test_launcher_processDiff(t *testing.T) { commitOracle.On("Start").Return(nil) execOracle := mocks.NewCCIPOracle(t) execOracle.On("Start").Return(nil) - m.EXPECT().Create(mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool { + m.EXPECT().Create(mock.Anything, mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool { return cfg.Config.PluginType == uint8(cctypes.PluginTypeCCIPCommit) })). Return(commitOracle, nil) - m.EXPECT().Create(mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool { + m.EXPECT().Create(mock.Anything, mock.Anything, mock.MatchedBy(func(cfg cctypes.OCR3ConfigWithMeta) bool { return cfg.Config.PluginType == uint8(cctypes.PluginTypeCCIPExec) })). Return(execOracle, nil) @@ -733,7 +736,7 @@ func Test_launcher_processDiff(t *testing.T) { homeChainReader: tt.fields.homeChainReader, oracleCreator: tt.fields.oracleCreator, } - err := l.processDiff(tt.args.diff) + err := l.processDiff(testutils.Context(t), tt.args.diff) if tt.wantErr { require.Error(t, err) } else { diff --git a/core/capabilities/ccip/oraclecreator/bootstrap.go b/core/capabilities/ccip/oraclecreator/bootstrap.go index 44ed824e569..632ac789c8e 100644 --- a/core/capabilities/ccip/oraclecreator/bootstrap.go +++ b/core/capabilities/ccip/oraclecreator/bootstrap.go @@ -140,7 +140,7 @@ func (i *bootstrapOracleCreator) Type() cctypes.OracleType { } // Create implements types.OracleCreator. -func (i *bootstrapOracleCreator) Create(_ uint32, config cctypes.OCR3ConfigWithMeta) (cctypes.CCIPOracle, error) { +func (i *bootstrapOracleCreator) Create(ctx context.Context, _ uint32, config cctypes.OCR3ConfigWithMeta) (cctypes.CCIPOracle, error) { // Assuming that the chain selector is referring to an evm chain for now. // TODO: add an api that returns chain family. // NOTE: this doesn't really matter for the bootstrap node, it doesn't do anything on-chain. @@ -158,7 +158,6 @@ func (i *bootstrapOracleCreator) Create(_ uint32, config cctypes.OCR3ConfigWithM oraclePeerIDs = append(oraclePeerIDs, n.P2pID) } - ctx := context.Background() rmnHomeReader, err := i.getRmnHomeReader(ctx, config) if err != nil { return nil, fmt.Errorf("failed to get RMNHome reader: %w", err) diff --git a/core/capabilities/ccip/oraclecreator/plugin.go b/core/capabilities/ccip/oraclecreator/plugin.go index ea08f150f10..5df0b1135d7 100644 --- a/core/capabilities/ccip/oraclecreator/plugin.go +++ b/core/capabilities/ccip/oraclecreator/plugin.go @@ -118,7 +118,7 @@ func (i *pluginOracleCreator) Type() cctypes.OracleType { } // Create implements types.OracleCreator. -func (i *pluginOracleCreator) Create(donID uint32, config cctypes.OCR3ConfigWithMeta) (cctypes.CCIPOracle, error) { +func (i *pluginOracleCreator) Create(ctx context.Context, donID uint32, config cctypes.OCR3ConfigWithMeta) (cctypes.CCIPOracle, error) { pluginType := cctypes.PluginType(config.Config.PluginType) // Assuming that the chain selector is referring to an evm chain for now. @@ -137,6 +137,7 @@ func (i *pluginOracleCreator) Create(donID uint32, config cctypes.OCR3ConfigWith } contractReaders, chainWriters, err := i.createReadersAndWriters( + ctx, destChainID, pluginType, config, @@ -294,6 +295,7 @@ func (i *pluginOracleCreator) createFactoryAndTransmitter( } func (i *pluginOracleCreator) createReadersAndWriters( + ctx context.Context, destChainID uint64, pluginType cctypes.PluginType, config cctypes.OCR3ConfigWithMeta, @@ -340,15 +342,14 @@ func (i *pluginOracleCreator) createReadersAndWriters( return nil, nil, fmt.Errorf("failed to get chain reader config: %w", err1) } - // TODO: context. - cr, err1 := relayer.NewContractReader(context.Background(), chainReaderConfig) + cr, err1 := relayer.NewContractReader(ctx, chainReaderConfig) if err1 != nil { return nil, nil, err1 } if chainID.Uint64() == destChainID { offrampAddressHex := common.BytesToAddress(config.Config.OfframpAddress).Hex() - err2 := cr.Bind(context.Background(), []types.BoundContract{ + err2 := cr.Bind(ctx, []types.BoundContract{ { Address: offrampAddressHex, Name: consts.ContractNameOffRamp, @@ -359,11 +360,12 @@ func (i *pluginOracleCreator) createReadersAndWriters( } } - if err2 := cr.Start(context.Background()); err2 != nil { + if err2 := cr.Start(ctx); err2 != nil { return nil, nil, fmt.Errorf("failed to start contract reader for chain %s: %w", chainID.String(), err2) } cw, err1 := createChainWriter( + ctx, chainID, i.evmConfigs, relayer, @@ -373,7 +375,7 @@ func (i *pluginOracleCreator) createReadersAndWriters( return nil, nil, err1 } - if err4 := cw.Start(context.Background()); err4 != nil { + if err4 := cw.Start(ctx); err4 != nil { return nil, nil, fmt.Errorf("failed to start chain writer for chain %s: %w", chainID.String(), err4) } @@ -476,6 +478,7 @@ func isUSDCEnabled(ofc offChainConfig) bool { } func createChainWriter( + ctx context.Context, chainID *big.Int, evmConfigs toml.EVMConfigs, relayer loop.Relayer, @@ -509,8 +512,7 @@ func createChainWriter( return nil, fmt.Errorf("failed to marshal chain writer config: %w", err) } - // TODO: context. - cw, err := relayer.NewChainWriter(context.Background(), chainWriterConfig) + cw, err := relayer.NewChainWriter(ctx, chainWriterConfig) if err != nil { return nil, fmt.Errorf("failed to create chain writer for chain %s: %w", chainID.String(), err) } diff --git a/core/capabilities/ccip/types/mocks/oracle_creator.go b/core/capabilities/ccip/types/mocks/oracle_creator.go index 51103c4a504..1906df7e063 100644 --- a/core/capabilities/ccip/types/mocks/oracle_creator.go +++ b/core/capabilities/ccip/types/mocks/oracle_creator.go @@ -3,6 +3,8 @@ package mocks import ( + context "context" + types "github.com/smartcontractkit/chainlink/v2/core/capabilities/ccip/types" mock "github.com/stretchr/testify/mock" ) @@ -20,9 +22,9 @@ func (_m *OracleCreator) EXPECT() *OracleCreator_Expecter { return &OracleCreator_Expecter{mock: &_m.Mock} } -// Create provides a mock function with given fields: donID, config -func (_m *OracleCreator) Create(donID uint32, config types.OCR3ConfigWithMeta) (types.CCIPOracle, error) { - ret := _m.Called(donID, config) +// Create provides a mock function with given fields: ctx, donID, config +func (_m *OracleCreator) Create(ctx context.Context, donID uint32, config types.OCR3ConfigWithMeta) (types.CCIPOracle, error) { + ret := _m.Called(ctx, donID, config) if len(ret) == 0 { panic("no return value specified for Create") @@ -30,19 +32,19 @@ func (_m *OracleCreator) Create(donID uint32, config types.OCR3ConfigWithMeta) ( var r0 types.CCIPOracle var r1 error - if rf, ok := ret.Get(0).(func(uint32, types.OCR3ConfigWithMeta) (types.CCIPOracle, error)); ok { - return rf(donID, config) + if rf, ok := ret.Get(0).(func(context.Context, uint32, types.OCR3ConfigWithMeta) (types.CCIPOracle, error)); ok { + return rf(ctx, donID, config) } - if rf, ok := ret.Get(0).(func(uint32, types.OCR3ConfigWithMeta) types.CCIPOracle); ok { - r0 = rf(donID, config) + if rf, ok := ret.Get(0).(func(context.Context, uint32, types.OCR3ConfigWithMeta) types.CCIPOracle); ok { + r0 = rf(ctx, donID, config) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(types.CCIPOracle) } } - if rf, ok := ret.Get(1).(func(uint32, types.OCR3ConfigWithMeta) error); ok { - r1 = rf(donID, config) + if rf, ok := ret.Get(1).(func(context.Context, uint32, types.OCR3ConfigWithMeta) error); ok { + r1 = rf(ctx, donID, config) } else { r1 = ret.Error(1) } @@ -56,15 +58,16 @@ type OracleCreator_Create_Call struct { } // Create is a helper method to define mock.On call +// - ctx context.Context // - donID uint32 // - config types.OCR3ConfigWithMeta -func (_e *OracleCreator_Expecter) Create(donID interface{}, config interface{}) *OracleCreator_Create_Call { - return &OracleCreator_Create_Call{Call: _e.mock.On("Create", donID, config)} +func (_e *OracleCreator_Expecter) Create(ctx interface{}, donID interface{}, config interface{}) *OracleCreator_Create_Call { + return &OracleCreator_Create_Call{Call: _e.mock.On("Create", ctx, donID, config)} } -func (_c *OracleCreator_Create_Call) Run(run func(donID uint32, config types.OCR3ConfigWithMeta)) *OracleCreator_Create_Call { +func (_c *OracleCreator_Create_Call) Run(run func(ctx context.Context, donID uint32, config types.OCR3ConfigWithMeta)) *OracleCreator_Create_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(uint32), args[1].(types.OCR3ConfigWithMeta)) + run(args[0].(context.Context), args[1].(uint32), args[2].(types.OCR3ConfigWithMeta)) }) return _c } @@ -74,7 +77,7 @@ func (_c *OracleCreator_Create_Call) Return(_a0 types.CCIPOracle, _a1 error) *Or return _c } -func (_c *OracleCreator_Create_Call) RunAndReturn(run func(uint32, types.OCR3ConfigWithMeta) (types.CCIPOracle, error)) *OracleCreator_Create_Call { +func (_c *OracleCreator_Create_Call) RunAndReturn(run func(context.Context, uint32, types.OCR3ConfigWithMeta) (types.CCIPOracle, error)) *OracleCreator_Create_Call { _c.Call.Return(run) return _c } diff --git a/core/capabilities/ccip/types/types.go b/core/capabilities/ccip/types/types.go index 04da1157b33..8341adf2030 100644 --- a/core/capabilities/ccip/types/types.go +++ b/core/capabilities/ccip/types/types.go @@ -1,6 +1,8 @@ package types import ( + "context" + ccipreaderpkg "github.com/smartcontractkit/chainlink-ccip/pkg/reader" ) @@ -46,7 +48,7 @@ type OracleCreator interface { // Create creates a new oracle that will run either the commit or exec ccip plugin, // if its a plugin oracle, or a bootstrap oracle if its a bootstrap oracle. // The oracle must be returned unstarted. - Create(donID uint32, config OCR3ConfigWithMeta) (CCIPOracle, error) + Create(ctx context.Context, donID uint32, config OCR3ConfigWithMeta) (CCIPOracle, error) // Type returns the type of oracle that this creator creates. // The only valid values are OracleTypePlugin and OracleTypeBootstrap. diff --git a/core/chains/evm/monitor/balance.go b/core/chains/evm/monitor/balance.go index 1f5275c13fb..b6cb9adb875 100644 --- a/core/chains/evm/monitor/balance.go +++ b/core/chains/evm/monitor/balance.go @@ -65,13 +65,13 @@ func NewBalanceMonitor(ethClient evmclient.Client, ethKeyStore keystore.Eth, lgg Start: bm.start, Close: bm.close, }.NewServiceEngine(lggr) - bm.sleeperTask = utils.NewSleeperTask(&worker{bm: bm}) + bm.sleeperTask = utils.NewSleeperTaskCtx(&worker{bm: bm}) return bm } func (bm *balanceMonitor) start(ctx context.Context) error { // Always query latest balance on start - (&worker{bm}).WorkCtx(ctx) + (&worker{bm}).Work(ctx) return nil } @@ -146,12 +146,7 @@ func (*worker) Name() string { return "BalanceMonitorWorker" } -func (w *worker) Work() { - // Used with SleeperTask - w.WorkCtx(context.Background()) -} - -func (w *worker) WorkCtx(ctx context.Context) { +func (w *worker) Work(ctx context.Context) { enabledAddresses, err := w.bm.ethKeyStore.EnabledAddressesForChain(ctx, w.bm.chainID) if err != nil { w.bm.eng.Error("BalanceMonitor: error getting keys", err) diff --git a/core/services/chainlink/application.go b/core/services/chainlink/application.go index 112b87cf0af..2c918b3a8d8 100644 --- a/core/services/chainlink/application.go +++ b/core/services/chainlink/application.go @@ -376,7 +376,9 @@ func NewApplication(opts ApplicationOpts) (Application, error) { if err != nil { return nil, errors.Wrap(err, "NewApplication: failed to initialize LDAP Authentication module") } - sessionReaper = ldapauth.NewLDAPServerStateSync(opts.DS, cfg.WebServer().LDAP(), globalLogger) + syncer := ldapauth.NewLDAPServerStateSyncer(opts.DS, cfg.WebServer().LDAP(), globalLogger) + srvcs = append(srvcs, syncer) + sessionReaper = utils.NewSleeperTaskCtx(syncer) case sessions.LocalAuth: authenticationProvider = localauth.NewORM(opts.DS, cfg.WebServer().SessionTimeout().Duration(), globalLogger, auditLogger) sessionReaper = localauth.NewSessionReaper(opts.DS, cfg.WebServer(), globalLogger) diff --git a/core/services/llo/mercurytransmitter/persistence_manager.go b/core/services/llo/mercurytransmitter/persistence_manager.go index eb36a7d1b80..ffa82493c9c 100644 --- a/core/services/llo/mercurytransmitter/persistence_manager.go +++ b/core/services/llo/mercurytransmitter/persistence_manager.go @@ -78,7 +78,7 @@ func (pm *persistenceManager) Load(ctx context.Context) ([]*Transmission, error) func (pm *persistenceManager) runFlushDeletesLoop() { defer pm.wg.Done() - ctx, cancel := pm.stopCh.Ctx(context.Background()) + ctx, cancel := pm.stopCh.NewCtx() defer cancel() ticker := services.NewTicker(pm.flushDeletesFrequency) diff --git a/core/services/llo/mercurytransmitter/server.go b/core/services/llo/mercurytransmitter/server.go index 72ff8b669ba..70e76655961 100644 --- a/core/services/llo/mercurytransmitter/server.go +++ b/core/services/llo/mercurytransmitter/server.go @@ -106,7 +106,7 @@ func (s *server) HealthReport() map[string]error { func (s *server) runDeleteQueueLoop(stopCh services.StopChan, wg *sync.WaitGroup) { defer wg.Done() - runloopCtx, cancel := stopCh.Ctx(context.Background()) + ctx, cancel := stopCh.NewCtx() defer cancel() // Exponential backoff for very rarely occurring errors (DB disconnect etc) @@ -121,8 +121,8 @@ func (s *server) runDeleteQueueLoop(stopCh services.StopChan, wg *sync.WaitGroup select { case hash := <-s.deleteQueue: for { - if err := s.pm.orm.Delete(runloopCtx, [][32]byte{hash}); err != nil { - s.lggr.Errorw("Failed to delete transmission record", "err", err, "transmissionHash", fmt.Sprintf("%x", hash)) + if err := s.pm.orm.Delete(ctx, [][32]byte{hash}); err != nil { + s.lggr.Errorw("Failed to delete transmission record", "err", err, "transmissionHash", hash) s.transmitQueueDeleteErrorCount.Inc() select { case <-time.After(b.Duration()): @@ -154,7 +154,7 @@ func (s *server) runQueueLoop(stopCh services.StopChan, wg *sync.WaitGroup, donI Factor: 2, Jitter: true, } - runloopCtx, cancel := stopCh.Ctx(context.Background()) + ctx, cancel := stopCh.NewCtx() defer cancel() for { t := s.q.BlockingPop() @@ -162,12 +162,13 @@ func (s *server) runQueueLoop(stopCh services.StopChan, wg *sync.WaitGroup, donI // queue was closed return } - ctx, cancel := context.WithTimeout(runloopCtx, utils.WithJitter(s.transmitTimeout)) - res, err := s.transmit(ctx, t) - cancel() - if runloopCtx.Err() != nil { - // runloop context is only canceled on transmitter close so we can - // exit the runloop here + res, err := func(ctx context.Context) (*pb.TransmitResponse, error) { + ctx, cancelFn := context.WithTimeout(ctx, utils.WithJitter(s.transmitTimeout)) + defer cancelFn() + return s.transmit(ctx, t) + }(ctx) + if ctx.Err() != nil { + // only canceled on transmitter close so we can exit return } else if err != nil { s.transmitConnectionErrorCount.Inc() diff --git a/core/services/llo/onchain_channel_definition_cache.go b/core/services/llo/onchain_channel_definition_cache.go index 8467a84aaef..3613108d133 100644 --- a/core/services/llo/onchain_channel_definition_cache.go +++ b/core/services/llo/onchain_channel_definition_cache.go @@ -108,7 +108,7 @@ type channelDefinitionCache struct { persistedVersion uint32 wg sync.WaitGroup - chStop chan struct{} + chStop services.StopChan } type HTTPClient interface { @@ -180,7 +180,7 @@ func (c *channelDefinitionCache) Start(ctx context.Context) error { func (c *channelDefinitionCache) pollChainLoop() { defer c.wg.Done() - ctx, cancel := services.StopChan(c.chStop).NewCtx() + ctx, cancel := c.chStop.NewCtx() defer cancel() pollT := services.NewTicker(c.logPollInterval) @@ -353,7 +353,7 @@ func (c *channelDefinitionCache) fetchAndSetChannelDefinitions(ctx context.Conte c.definitionsVersion = log.Version c.definitionsMu.Unlock() - if memoryVersion, persistedVersion, err := c.persist(context.Background()); err != nil { + if memoryVersion, persistedVersion, err := c.persist(ctx); err != nil { // If this fails, the failedPersistLoop will try again c.lggr.Warnw("Failed to persist channel definitions", "err", err, "memoryVersion", memoryVersion, "persistedVersion", persistedVersion) } @@ -457,7 +457,7 @@ func (c *channelDefinitionCache) persist(ctx context.Context) (memoryVersion, pe func (c *channelDefinitionCache) failedPersistLoop() { defer c.wg.Done() - ctx, cancel := services.StopChan(c.chStop).NewCtx() + ctx, cancel := c.chStop.NewCtx() defer cancel() for { diff --git a/core/services/ocr2/plugins/ccip/exportinternal.go b/core/services/ocr2/plugins/ccip/exportinternal.go index aecf1a0b163..6b24cba4857 100644 --- a/core/services/ocr2/plugins/ccip/exportinternal.go +++ b/core/services/ocr2/plugins/ccip/exportinternal.go @@ -38,32 +38,32 @@ func NewEvmPriceRegistry(lp logpoller.LogPoller, ec client.Client, lggr logger.L type VersionFinder = factory.VersionFinder -func NewCommitStoreReader(lggr logger.Logger, versionFinder VersionFinder, address ccip.Address, ec client.Client, lp logpoller.LogPoller) (ccipdata.CommitStoreReader, error) { - return factory.NewCommitStoreReader(lggr, versionFinder, address, ec, lp) +func NewCommitStoreReader(ctx context.Context, lggr logger.Logger, versionFinder VersionFinder, address ccip.Address, ec client.Client, lp logpoller.LogPoller) (ccipdata.CommitStoreReader, error) { + return factory.NewCommitStoreReader(ctx, lggr, versionFinder, address, ec, lp) } -func CloseCommitStoreReader(lggr logger.Logger, versionFinder VersionFinder, address ccip.Address, ec client.Client, lp logpoller.LogPoller) error { - return factory.CloseCommitStoreReader(lggr, versionFinder, address, ec, lp) +func CloseCommitStoreReader(ctx context.Context, lggr logger.Logger, versionFinder VersionFinder, address ccip.Address, ec client.Client, lp logpoller.LogPoller) error { + return factory.CloseCommitStoreReader(ctx, lggr, versionFinder, address, ec, lp) } -func NewOffRampReader(lggr logger.Logger, versionFinder VersionFinder, addr ccip.Address, destClient client.Client, lp logpoller.LogPoller, estimator gas.EvmFeeEstimator, destMaxGasPrice *big.Int, registerFilters bool) (ccipdata.OffRampReader, error) { - return factory.NewOffRampReader(lggr, versionFinder, addr, destClient, lp, estimator, destMaxGasPrice, registerFilters) +func NewOffRampReader(ctx context.Context, lggr logger.Logger, versionFinder VersionFinder, addr ccip.Address, destClient client.Client, lp logpoller.LogPoller, estimator gas.EvmFeeEstimator, destMaxGasPrice *big.Int, registerFilters bool) (ccipdata.OffRampReader, error) { + return factory.NewOffRampReader(ctx, lggr, versionFinder, addr, destClient, lp, estimator, destMaxGasPrice, registerFilters) } -func CloseOffRampReader(lggr logger.Logger, versionFinder VersionFinder, addr ccip.Address, destClient client.Client, lp logpoller.LogPoller, estimator gas.EvmFeeEstimator, destMaxGasPrice *big.Int) error { - return factory.CloseOffRampReader(lggr, versionFinder, addr, destClient, lp, estimator, destMaxGasPrice) +func CloseOffRampReader(ctx context.Context, lggr logger.Logger, versionFinder VersionFinder, addr ccip.Address, destClient client.Client, lp logpoller.LogPoller, estimator gas.EvmFeeEstimator, destMaxGasPrice *big.Int) error { + return factory.CloseOffRampReader(ctx, lggr, versionFinder, addr, destClient, lp, estimator, destMaxGasPrice) } func NewEvmVersionFinder() factory.EvmVersionFinder { return factory.NewEvmVersionFinder() } -func NewOnRampReader(lggr logger.Logger, versionFinder VersionFinder, sourceSelector, destSelector uint64, onRampAddress ccip.Address, sourceLP logpoller.LogPoller, source client.Client) (ccipdata.OnRampReader, error) { - return factory.NewOnRampReader(lggr, versionFinder, sourceSelector, destSelector, onRampAddress, sourceLP, source) +func NewOnRampReader(ctx context.Context, lggr logger.Logger, versionFinder VersionFinder, sourceSelector, destSelector uint64, onRampAddress ccip.Address, sourceLP logpoller.LogPoller, source client.Client) (ccipdata.OnRampReader, error) { + return factory.NewOnRampReader(ctx, lggr, versionFinder, sourceSelector, destSelector, onRampAddress, sourceLP, source) } -func CloseOnRampReader(lggr logger.Logger, versionFinder VersionFinder, sourceSelector, destSelector uint64, onRampAddress ccip.Address, sourceLP logpoller.LogPoller, source client.Client) error { - return factory.CloseOnRampReader(lggr, versionFinder, sourceSelector, destSelector, onRampAddress, sourceLP, source) +func CloseOnRampReader(ctx context.Context, lggr logger.Logger, versionFinder VersionFinder, sourceSelector, destSelector uint64, onRampAddress ccip.Address, sourceLP logpoller.LogPoller, source client.Client) error { + return factory.CloseOnRampReader(ctx, lggr, versionFinder, sourceSelector, destSelector, onRampAddress, sourceLP, source) } type OffRampReader = ccipdata.OffRampReader @@ -86,12 +86,12 @@ func NewDynamicLimitedBatchCaller( return rpclib.NewDynamicLimitedBatchCaller(lggr, batchSender, batchSizeLimit, backOffMultiplier, parallelRpcCallsLimit) } -func NewUSDCReader(lggr logger.Logger, jobID string, transmitter common.Address, lp logpoller.LogPoller, registerFilters bool) (*ccipdata.USDCReaderImpl, error) { - return ccipdata.NewUSDCReader(lggr, jobID, transmitter, lp, registerFilters) +func NewUSDCReader(ctx context.Context, lggr logger.Logger, jobID string, transmitter common.Address, lp logpoller.LogPoller, registerFilters bool) (*ccipdata.USDCReaderImpl, error) { + return ccipdata.NewUSDCReader(ctx, lggr, jobID, transmitter, lp, registerFilters) } -func CloseUSDCReader(lggr logger.Logger, jobID string, transmitter common.Address, lp logpoller.LogPoller) error { - return ccipdata.CloseUSDCReader(lggr, jobID, transmitter, lp) +func CloseUSDCReader(ctx context.Context, lggr logger.Logger, jobID string, transmitter common.Address, lp logpoller.LogPoller) error { + return ccipdata.CloseUSDCReader(ctx, lggr, jobID, transmitter, lp) } type USDCReaderImpl = ccipdata.USDCReaderImpl diff --git a/core/services/ocr2/plugins/ccip/internal/cache/chain_health.go b/core/services/ocr2/plugins/ccip/internal/cache/chain_health.go index 00f90615eb2..b029ee02132 100644 --- a/core/services/ocr2/plugins/ccip/internal/cache/chain_health.go +++ b/core/services/ocr2/plugins/ccip/internal/cache/chain_health.go @@ -57,15 +57,12 @@ type chainHealthcheck struct { commitStore ccipdata.CommitStoreReader services.StateMachine - wg *sync.WaitGroup - backgroundCtx context.Context //nolint:containedctx - backgroundCancel context.CancelFunc + wg sync.WaitGroup + stopChan services.StopChan } func NewChainHealthcheck(lggr logger.Logger, onRamp ccipdata.OnRampReader, commitStore ccipdata.CommitStoreReader) *chainHealthcheck { - ctx, cancel := context.WithCancel(context.Background()) - - ch := &chainHealthcheck{ + return &chainHealthcheck{ // Different keys use different expiration times, so we don't need to worry about the default value cache: cache.New(cache.NoExpiration, 0), rmnStatusKey: rmnStatusKey, @@ -76,18 +73,12 @@ func NewChainHealthcheck(lggr logger.Logger, onRamp ccipdata.OnRampReader, commi lggr: lggr, onRamp: onRamp, commitStore: commitStore, - - wg: new(sync.WaitGroup), - backgroundCtx: ctx, - backgroundCancel: cancel, + stopChan: make(services.StopChan), } - return ch } // newChainHealthcheckWithCustomEviction is used for testing purposes only. It doesn't start background worker func newChainHealthcheckWithCustomEviction(lggr logger.Logger, onRamp ccipdata.OnRampReader, commitStore ccipdata.CommitStoreReader, globalStatusDuration time.Duration, rmnStatusRefreshInterval time.Duration) *chainHealthcheck { - ctx, cancel := context.WithCancel(context.Background()) - return &chainHealthcheck{ cache: cache.New(rmnStatusRefreshInterval, 0), rmnStatusKey: rmnStatusKey, @@ -98,10 +89,7 @@ func newChainHealthcheckWithCustomEviction(lggr logger.Logger, onRamp ccipdata.O lggr: lggr, onRamp: onRamp, commitStore: commitStore, - - wg: new(sync.WaitGroup), - backgroundCtx: ctx, - backgroundCancel: cancel, + stopChan: make(services.StopChan), } } @@ -145,7 +133,6 @@ func (c *chainHealthcheck) IsHealthy(ctx context.Context) (bool, error) { func (c *chainHealthcheck) Start(context.Context) error { return c.StateMachine.StartOnce("ChainHealthcheck", func() error { c.lggr.Info("Starting ChainHealthcheck") - c.wg.Add(1) c.run() return nil }) @@ -154,7 +141,7 @@ func (c *chainHealthcheck) Start(context.Context) error { func (c *chainHealthcheck) Close() error { return c.StateMachine.StopOnce("ChainHealthcheck", func() error { c.lggr.Info("Closing ChainHealthcheck") - c.backgroundCancel() + close(c.stopChan) c.wg.Wait() return nil }) @@ -162,17 +149,20 @@ func (c *chainHealthcheck) Close() error { func (c *chainHealthcheck) run() { ticker := time.NewTicker(c.rmnStatusRefreshInterval) + c.wg.Add(1) go func() { defer c.wg.Done() + ctx, cancel := c.stopChan.NewCtx() + defer cancel() // Refresh the RMN state immediately after starting the background refresher - _, _ = c.refresh(c.backgroundCtx) + _, _ = c.refresh(ctx) for { select { - case <-c.backgroundCtx.Done(): + case <-ctx.Done(): return case <-ticker.C: - _, err := c.refresh(c.backgroundCtx) + _, err := c.refresh(ctx) if err != nil { c.lggr.Errorw("Failed to refresh RMN state in the background", "err", err) } diff --git a/core/services/ocr2/plugins/ccip/internal/ccipdata/commit_store_reader_test.go b/core/services/ocr2/plugins/ccip/internal/ccipdata/commit_store_reader_test.go index f46b1b55b1f..0f234bab8a6 100644 --- a/core/services/ocr2/plugins/ccip/internal/ccipdata/commit_store_reader_test.go +++ b/core/services/ocr2/plugins/ccip/internal/ccipdata/commit_store_reader_test.go @@ -1,7 +1,6 @@ package ccipdata_test import ( - "context" "math/big" "reflect" "testing" @@ -15,6 +14,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/config" "github.com/smartcontractkit/chainlink-common/pkg/logger" cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccip" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/assets" evmclientmocks "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client/mocks" @@ -180,7 +180,7 @@ func TestCommitStoreReaders(t *testing.T) { ge.On("L1Oracle").Return(lm) maxGasPrice := big.NewInt(1e8) - c12r, err := factory.NewCommitStoreReader(lggr, factory.NewEvmVersionFinder(), ccipcalc.EvmAddrToGeneric(addr2), ec, lp) + c12r, err := factory.NewCommitStoreReader(ctx, lggr, factory.NewEvmVersionFinder(), ccipcalc.EvmAddrToGeneric(addr2), ec, lp) require.NoError(t, err) err = c12r.SetGasEstimator(ctx, ge) require.NoError(t, err) @@ -228,7 +228,7 @@ func TestCommitStoreReaders(t *testing.T) { commitAndGetBlockTs(ec) // Capture all logs. - lp.PollAndSaveLogs(context.Background(), 1) + lp.PollAndSaveLogs(ctx, 1) configs := map[string][][]byte{ ccipdata.V1_2_0: {onchainConfig2, offchainConfig2}, @@ -248,7 +248,7 @@ func TestCommitStoreReaders(t *testing.T) { cr := cr t.Run("CommitStoreReader "+v, func(t *testing.T) { // Static config. - cfg, err := cr.GetCommitStoreStaticConfig(context.Background()) + cfg, err := cr.GetCommitStoreStaticConfig(ctx) require.NoError(t, err) require.NotNil(t, cfg) @@ -260,33 +260,33 @@ func TestCommitStoreReaders(t *testing.T) { assert.Equal(t, d, rep) // Assert reading - latest, err := cr.GetLatestPriceEpochAndRound(context.Background()) + latest, err := cr.GetLatestPriceEpochAndRound(ctx) require.NoError(t, err) assert.Equal(t, er.Uint64(), latest) // Assert cursing - down, err := cr.IsDown(context.Background()) + down, err := cr.IsDown(ctx) require.NoError(t, err) assert.False(t, down) _, err = arm.VoteToCurse(user, [32]byte{}) require.NoError(t, err) ec.Commit() - down, err = cr.IsDown(context.Background()) + down, err = cr.IsDown(ctx) require.NoError(t, err) assert.True(t, down) _, err = arm.OwnerUnvoteToCurse0(user, nil) require.NoError(t, err) ec.Commit() - seqNr, err := cr.GetExpectedNextSequenceNumber(context.Background()) + seqNr, err := cr.GetExpectedNextSequenceNumber(ctx) require.NoError(t, err) assert.Equal(t, rep.Interval.Max+1, seqNr) - reps, err := cr.GetCommitReportMatchingSeqNum(context.Background(), rep.Interval.Max+1, 0) + reps, err := cr.GetCommitReportMatchingSeqNum(ctx, rep.Interval.Max+1, 0) require.NoError(t, err) assert.Len(t, reps, 0) - reps, err = cr.GetCommitReportMatchingSeqNum(context.Background(), rep.Interval.Max, 0) + reps, err = cr.GetCommitReportMatchingSeqNum(ctx, rep.Interval.Max, 0) require.NoError(t, err) require.Len(t, reps, 1) assert.Equal(t, reps[0].Interval, rep.Interval) @@ -294,7 +294,7 @@ func TestCommitStoreReaders(t *testing.T) { assert.Equal(t, reps[0].GasPrices, rep.GasPrices) assert.Equal(t, reps[0].TokenPrices, rep.TokenPrices) - reps, err = cr.GetCommitReportMatchingSeqNum(context.Background(), rep.Interval.Min, 0) + reps, err = cr.GetCommitReportMatchingSeqNum(ctx, rep.Interval.Min, 0) require.NoError(t, err) require.Len(t, reps, 1) assert.Equal(t, reps[0].Interval, rep.Interval) @@ -302,12 +302,12 @@ func TestCommitStoreReaders(t *testing.T) { assert.Equal(t, reps[0].GasPrices, rep.GasPrices) assert.Equal(t, reps[0].TokenPrices, rep.TokenPrices) - reps, err = cr.GetCommitReportMatchingSeqNum(context.Background(), rep.Interval.Min-1, 0) + reps, err = cr.GetCommitReportMatchingSeqNum(ctx, rep.Interval.Min-1, 0) require.NoError(t, err) require.Len(t, reps, 0) // Sanity - reps, err = cr.GetAcceptedCommitReportsGteTimestamp(context.Background(), time.Unix(0, 0), 0) + reps, err = cr.GetAcceptedCommitReportsGteTimestamp(ctx, time.Unix(0, 0), 0) require.NoError(t, err) require.Len(t, reps, 1) assert.Equal(t, reps[0].Interval, rep.Interval) @@ -329,7 +329,7 @@ func TestCommitStoreReaders(t *testing.T) { // We should be able to query for gas prices now. gpe, err := cr.GasPriceEstimator(ctx) require.NoError(t, err) - gp, err := gpe.GetGasPrice(context.Background()) + gp, err := gpe.GetGasPrice(ctx) require.NoError(t, err) assert.True(t, gp.Cmp(big.NewInt(0)) > 0) }) @@ -360,6 +360,7 @@ func TestNewCommitStoreReader(t *testing.T) { } for _, tc := range tt { t.Run(tc.typeAndVersion, func(t *testing.T) { + ctx := tests.Context(t) b, err := utils.ABIEncode(`[{"type":"string"}]`, tc.typeAndVersion) require.NoError(t, err) c := evmclientmocks.NewClient(t) @@ -369,7 +370,7 @@ func TestNewCommitStoreReader(t *testing.T) { if tc.expectedErr == "" { lp.On("RegisterFilter", mock.Anything, mock.Anything).Return(nil) } - _, err = factory.NewCommitStoreReader(logger.Test(t), factory.NewEvmVersionFinder(), addr, c, lp) + _, err = factory.NewCommitStoreReader(ctx, logger.Test(t), factory.NewEvmVersionFinder(), addr, c, lp) if tc.expectedErr != "" { require.EqualError(t, err, tc.expectedErr) } else { diff --git a/core/services/ocr2/plugins/ccip/internal/ccipdata/factory/commit_store.go b/core/services/ocr2/plugins/ccip/internal/ccipdata/factory/commit_store.go index ec4cdded9a7..d9cd523d75e 100644 --- a/core/services/ocr2/plugins/ccip/internal/ccipdata/factory/commit_store.go +++ b/core/services/ocr2/plugins/ccip/internal/ccipdata/factory/commit_store.go @@ -1,6 +1,8 @@ package factory import ( + "context" + "github.com/Masterminds/semver/v3" "github.com/pkg/errors" @@ -19,16 +21,16 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/ccipdata/v1_5_0" ) -func NewCommitStoreReader(lggr logger.Logger, versionFinder VersionFinder, address cciptypes.Address, ec client.Client, lp logpoller.LogPoller) (ccipdata.CommitStoreReader, error) { - return initOrCloseCommitStoreReader(lggr, versionFinder, address, ec, lp, false) +func NewCommitStoreReader(ctx context.Context, lggr logger.Logger, versionFinder VersionFinder, address cciptypes.Address, ec client.Client, lp logpoller.LogPoller) (ccipdata.CommitStoreReader, error) { + return initOrCloseCommitStoreReader(ctx, lggr, versionFinder, address, ec, lp, false) } -func CloseCommitStoreReader(lggr logger.Logger, versionFinder VersionFinder, address cciptypes.Address, ec client.Client, lp logpoller.LogPoller) error { - _, err := initOrCloseCommitStoreReader(lggr, versionFinder, address, ec, lp, true) +func CloseCommitStoreReader(ctx context.Context, lggr logger.Logger, versionFinder VersionFinder, address cciptypes.Address, ec client.Client, lp logpoller.LogPoller) error { + _, err := initOrCloseCommitStoreReader(ctx, lggr, versionFinder, address, ec, lp, true) return err } -func initOrCloseCommitStoreReader(lggr logger.Logger, versionFinder VersionFinder, address cciptypes.Address, ec client.Client, lp logpoller.LogPoller, closeReader bool) (ccipdata.CommitStoreReader, error) { +func initOrCloseCommitStoreReader(ctx context.Context, lggr logger.Logger, versionFinder VersionFinder, address cciptypes.Address, ec client.Client, lp logpoller.LogPoller, closeReader bool) (ccipdata.CommitStoreReader, error) { contractType, version, err := versionFinder.TypeAndVersion(address, ec) if err != nil { return nil, errors.Wrapf(err, "unable to read type and version") @@ -53,7 +55,7 @@ func initOrCloseCommitStoreReader(lggr logger.Logger, versionFinder VersionFinde if closeReader { return nil, cs.Close() } - return cs, cs.RegisterFilters() + return cs, cs.RegisterFilters(ctx) case ccipdata.V1_5_0: cs, err := v1_5_0.NewCommitStore(lggr, evmAddr, ec, lp) if err != nil { @@ -62,7 +64,7 @@ func initOrCloseCommitStoreReader(lggr logger.Logger, versionFinder VersionFinde if closeReader { return nil, cs.Close() } - return cs, cs.RegisterFilters() + return cs, cs.RegisterFilters(ctx) default: return nil, errors.Errorf("unsupported commit store version %v", version.String()) } diff --git a/core/services/ocr2/plugins/ccip/internal/ccipdata/factory/commit_store_test.go b/core/services/ocr2/plugins/ccip/internal/ccipdata/factory/commit_store_test.go index 6beb6953d1a..cd81a0633ce 100644 --- a/core/services/ocr2/plugins/ccip/internal/ccipdata/factory/commit_store_test.go +++ b/core/services/ocr2/plugins/ccip/internal/ccipdata/factory/commit_store_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/mock" cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccip" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" "github.com/smartcontractkit/chainlink-common/pkg/logger" @@ -20,6 +21,7 @@ import ( ) func TestCommitStore(t *testing.T) { + ctx := tests.Context(t) for _, versionStr := range []string{ccipdata.V1_2_0} { lggr := logger.Test(t) addr := cciptypes.Address(utils.RandomAddress().String()) @@ -27,12 +29,12 @@ func TestCommitStore(t *testing.T) { lp.On("RegisterFilter", mock.Anything, mock.Anything).Return(nil) versionFinder := newMockVersionFinder(ccipconfig.CommitStore, *semver.MustParse(versionStr), nil) - _, err := NewCommitStoreReader(lggr, versionFinder, addr, nil, lp) + _, err := NewCommitStoreReader(ctx, lggr, versionFinder, addr, nil, lp) assert.NoError(t, err) expFilterName := logpoller.FilterName(v1_2_0.ExecReportAccepts, addr) lp.On("UnregisterFilter", mock.Anything, expFilterName).Return(nil) - err = CloseCommitStoreReader(lggr, versionFinder, addr, nil, lp) + err = CloseCommitStoreReader(ctx, lggr, versionFinder, addr, nil, lp) assert.NoError(t, err) } } diff --git a/core/services/ocr2/plugins/ccip/internal/ccipdata/factory/offramp.go b/core/services/ocr2/plugins/ccip/internal/ccipdata/factory/offramp.go index f0f26fb37f3..136079b5b3e 100644 --- a/core/services/ocr2/plugins/ccip/internal/ccipdata/factory/offramp.go +++ b/core/services/ocr2/plugins/ccip/internal/ccipdata/factory/offramp.go @@ -24,16 +24,16 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/ccipdata/v1_5_0" ) -func NewOffRampReader(lggr logger.Logger, versionFinder VersionFinder, addr cciptypes.Address, destClient client.Client, lp logpoller.LogPoller, estimator gas.EvmFeeEstimator, destMaxGasPrice *big.Int, registerFilters bool) (ccipdata.OffRampReader, error) { - return initOrCloseOffRampReader(lggr, versionFinder, addr, destClient, lp, estimator, destMaxGasPrice, false, registerFilters) +func NewOffRampReader(ctx context.Context, lggr logger.Logger, versionFinder VersionFinder, addr cciptypes.Address, destClient client.Client, lp logpoller.LogPoller, estimator gas.EvmFeeEstimator, destMaxGasPrice *big.Int, registerFilters bool) (ccipdata.OffRampReader, error) { + return initOrCloseOffRampReader(ctx, lggr, versionFinder, addr, destClient, lp, estimator, destMaxGasPrice, false, registerFilters) } -func CloseOffRampReader(lggr logger.Logger, versionFinder VersionFinder, addr cciptypes.Address, destClient client.Client, lp logpoller.LogPoller, estimator gas.EvmFeeEstimator, destMaxGasPrice *big.Int) error { - _, err := initOrCloseOffRampReader(lggr, versionFinder, addr, destClient, lp, estimator, destMaxGasPrice, true, false) +func CloseOffRampReader(ctx context.Context, lggr logger.Logger, versionFinder VersionFinder, addr cciptypes.Address, destClient client.Client, lp logpoller.LogPoller, estimator gas.EvmFeeEstimator, destMaxGasPrice *big.Int) error { + _, err := initOrCloseOffRampReader(ctx, lggr, versionFinder, addr, destClient, lp, estimator, destMaxGasPrice, true, false) return err } -func initOrCloseOffRampReader(lggr logger.Logger, versionFinder VersionFinder, addr cciptypes.Address, destClient client.Client, lp logpoller.LogPoller, estimator gas.EvmFeeEstimator, destMaxGasPrice *big.Int, closeReader bool, registerFilters bool) (ccipdata.OffRampReader, error) { +func initOrCloseOffRampReader(ctx context.Context, lggr logger.Logger, versionFinder VersionFinder, addr cciptypes.Address, destClient client.Client, lp logpoller.LogPoller, estimator gas.EvmFeeEstimator, destMaxGasPrice *big.Int, closeReader bool, registerFilters bool) (ccipdata.OffRampReader, error) { contractType, version, err := versionFinder.TypeAndVersion(addr, destClient) if err != nil { return nil, errors.Wrapf(err, "unable to read type and version") @@ -58,7 +58,7 @@ func initOrCloseOffRampReader(lggr logger.Logger, versionFinder VersionFinder, a if closeReader { return nil, offRamp.Close() } - return offRamp, offRamp.RegisterFilters() + return offRamp, offRamp.RegisterFilters(ctx) case ccipdata.V1_5_0: offRamp, err := v1_5_0.NewOffRamp(lggr, evmAddr, destClient, lp, estimator, destMaxGasPrice) if err != nil { @@ -67,7 +67,7 @@ func initOrCloseOffRampReader(lggr logger.Logger, versionFinder VersionFinder, a if closeReader { return nil, offRamp.Close() } - return offRamp, offRamp.RegisterFilters() + return offRamp, offRamp.RegisterFilters(ctx) default: return nil, errors.Errorf("unsupported offramp version %v", version.String()) } diff --git a/core/services/ocr2/plugins/ccip/internal/ccipdata/factory/offramp_test.go b/core/services/ocr2/plugins/ccip/internal/ccipdata/factory/offramp_test.go index 1851a6fb612..bfb8da5e32c 100644 --- a/core/services/ocr2/plugins/ccip/internal/ccipdata/factory/offramp_test.go +++ b/core/services/ocr2/plugins/ccip/internal/ccipdata/factory/offramp_test.go @@ -9,6 +9,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/logger" cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccip" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller" mocks2 "github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller/mocks" @@ -19,6 +20,7 @@ import ( ) func TestOffRamp(t *testing.T) { + ctx := tests.Context(t) for _, versionStr := range []string{ccipdata.V1_2_0} { lggr := logger.Test(t) addr := cciptypes.Address(utils.RandomAddress().String()) @@ -32,13 +34,13 @@ func TestOffRamp(t *testing.T) { versionFinder := newMockVersionFinder(ccipconfig.EVM2EVMOffRamp, *semver.MustParse(versionStr), nil) lp.On("RegisterFilter", mock.Anything, mock.Anything).Return(nil).Times(len(expFilterNames)) - _, err := NewOffRampReader(lggr, versionFinder, addr, nil, lp, nil, nil, true) + _, err := NewOffRampReader(ctx, lggr, versionFinder, addr, nil, lp, nil, nil, true) assert.NoError(t, err) for _, f := range expFilterNames { lp.On("UnregisterFilter", mock.Anything, f).Return(nil) } - err = CloseOffRampReader(lggr, versionFinder, addr, nil, lp, nil, nil) + err = CloseOffRampReader(ctx, lggr, versionFinder, addr, nil, lp, nil, nil) assert.NoError(t, err) } } diff --git a/core/services/ocr2/plugins/ccip/internal/ccipdata/factory/onramp.go b/core/services/ocr2/plugins/ccip/internal/ccipdata/factory/onramp.go index e04a34f72de..57bf6e2eeb3 100644 --- a/core/services/ocr2/plugins/ccip/internal/ccipdata/factory/onramp.go +++ b/core/services/ocr2/plugins/ccip/internal/ccipdata/factory/onramp.go @@ -1,6 +1,8 @@ package factory import ( + "context" + "github.com/pkg/errors" "github.com/smartcontractkit/chainlink-common/pkg/logger" @@ -17,16 +19,16 @@ import ( ) // NewOnRampReader determines the appropriate version of the onramp and returns a reader for it -func NewOnRampReader(lggr logger.Logger, versionFinder VersionFinder, sourceSelector, destSelector uint64, onRampAddress cciptypes.Address, sourceLP logpoller.LogPoller, source client.Client) (ccipdata.OnRampReader, error) { - return initOrCloseOnRampReader(lggr, versionFinder, sourceSelector, destSelector, onRampAddress, sourceLP, source, false) +func NewOnRampReader(ctx context.Context, lggr logger.Logger, versionFinder VersionFinder, sourceSelector, destSelector uint64, onRampAddress cciptypes.Address, sourceLP logpoller.LogPoller, source client.Client) (ccipdata.OnRampReader, error) { + return initOrCloseOnRampReader(ctx, lggr, versionFinder, sourceSelector, destSelector, onRampAddress, sourceLP, source, false) } -func CloseOnRampReader(lggr logger.Logger, versionFinder VersionFinder, sourceSelector, destSelector uint64, onRampAddress cciptypes.Address, sourceLP logpoller.LogPoller, source client.Client) error { - _, err := initOrCloseOnRampReader(lggr, versionFinder, sourceSelector, destSelector, onRampAddress, sourceLP, source, true) +func CloseOnRampReader(ctx context.Context, lggr logger.Logger, versionFinder VersionFinder, sourceSelector, destSelector uint64, onRampAddress cciptypes.Address, sourceLP logpoller.LogPoller, source client.Client) error { + _, err := initOrCloseOnRampReader(ctx, lggr, versionFinder, sourceSelector, destSelector, onRampAddress, sourceLP, source, true) return err } -func initOrCloseOnRampReader(lggr logger.Logger, versionFinder VersionFinder, sourceSelector, destSelector uint64, onRampAddress cciptypes.Address, sourceLP logpoller.LogPoller, source client.Client, closeReader bool) (ccipdata.OnRampReader, error) { +func initOrCloseOnRampReader(ctx context.Context, lggr logger.Logger, versionFinder VersionFinder, sourceSelector, destSelector uint64, onRampAddress cciptypes.Address, sourceLP logpoller.LogPoller, source client.Client, closeReader bool) (ccipdata.OnRampReader, error) { contractType, version, err := versionFinder.TypeAndVersion(onRampAddress, source) if err != nil { return nil, errors.Wrapf(err, "unable to read type and version") @@ -51,7 +53,7 @@ func initOrCloseOnRampReader(lggr logger.Logger, versionFinder VersionFinder, so if closeReader { return nil, onRamp.Close() } - return onRamp, onRamp.RegisterFilters() + return onRamp, onRamp.RegisterFilters(ctx) case ccipdata.V1_5_0: onRamp, err := v1_5_0.NewOnRamp(lggr, sourceSelector, destSelector, onRampAddrEvm, sourceLP, source) if err != nil { @@ -60,7 +62,7 @@ func initOrCloseOnRampReader(lggr logger.Logger, versionFinder VersionFinder, so if closeReader { return nil, onRamp.Close() } - return onRamp, onRamp.RegisterFilters() + return onRamp, onRamp.RegisterFilters(ctx) // Adding a new version? // Please update the public factory function in leafer.go if the new version updates the leaf hash function. default: diff --git a/core/services/ocr2/plugins/ccip/internal/ccipdata/factory/onramp_test.go b/core/services/ocr2/plugins/ccip/internal/ccipdata/factory/onramp_test.go index 320c8d6c301..bc1351f97c9 100644 --- a/core/services/ocr2/plugins/ccip/internal/ccipdata/factory/onramp_test.go +++ b/core/services/ocr2/plugins/ccip/internal/ccipdata/factory/onramp_test.go @@ -9,6 +9,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/logger" cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccip" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller" mocks2 "github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller/mocks" @@ -18,6 +19,7 @@ import ( ) func TestOnRamp(t *testing.T) { + ctx := tests.Context(t) for _, versionStr := range []string{ccipdata.V1_2_0, ccipdata.V1_5_0} { lggr := logger.Test(t) addr := cciptypes.Address(utils.RandomAddress().String()) @@ -33,13 +35,13 @@ func TestOnRamp(t *testing.T) { versionFinder := newMockVersionFinder(ccipconfig.EVM2EVMOnRamp, *semver.MustParse(versionStr), nil) lp.On("RegisterFilter", mock.Anything, mock.Anything).Return(nil).Times(len(expFilterNames)) - _, err := NewOnRampReader(lggr, versionFinder, sourceSelector, destSelector, addr, lp, nil) + _, err := NewOnRampReader(ctx, lggr, versionFinder, sourceSelector, destSelector, addr, lp, nil) assert.NoError(t, err) for _, f := range expFilterNames { lp.On("UnregisterFilter", mock.Anything, f).Return(nil) } - err = CloseOnRampReader(lggr, versionFinder, sourceSelector, destSelector, addr, lp, nil) + err = CloseOnRampReader(ctx, lggr, versionFinder, sourceSelector, destSelector, addr, lp, nil) assert.NoError(t, err) } } diff --git a/core/services/ocr2/plugins/ccip/internal/ccipdata/factory/price_registry.go b/core/services/ocr2/plugins/ccip/internal/ccipdata/factory/price_registry.go index cb82e7273bf..90a40eee1a5 100644 --- a/core/services/ocr2/plugins/ccip/internal/ccipdata/factory/price_registry.go +++ b/core/services/ocr2/plugins/ccip/internal/ccipdata/factory/price_registry.go @@ -43,7 +43,7 @@ func initOrClosePriceRegistryReader(ctx context.Context, lggr logger.Logger, ver } switch version.String() { case ccipdata.V1_2_0: - pr, err := v1_2_0.NewPriceRegistry(lggr, priceRegistryEvmAddr, lp, cl, registerFilters) + pr, err := v1_2_0.NewPriceRegistry(ctx, lggr, priceRegistryEvmAddr, lp, cl, registerFilters) if err != nil { return nil, err } @@ -52,7 +52,7 @@ func initOrClosePriceRegistryReader(ctx context.Context, lggr logger.Logger, ver } return pr, nil case ccipdata.V1_6_0: - pr, err := v1_2_0.NewPriceRegistry(lggr, priceRegistryEvmAddr, lp, cl, registerFilters) + pr, err := v1_2_0.NewPriceRegistry(ctx, lggr, priceRegistryEvmAddr, lp, cl, registerFilters) if err != nil { return nil, err } diff --git a/core/services/ocr2/plugins/ccip/internal/ccipdata/offramp_reader_test.go b/core/services/ocr2/plugins/ccip/internal/ccipdata/offramp_reader_test.go index d0b3fe53436..17f9bcfb370 100644 --- a/core/services/ocr2/plugins/ccip/internal/ccipdata/offramp_reader_test.go +++ b/core/services/ocr2/plugins/ccip/internal/ccipdata/offramp_reader_test.go @@ -14,6 +14,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/logger" cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccip" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" evmclientmocks "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client/mocks" @@ -139,7 +140,7 @@ func setupOffRampReaderTH(t *testing.T, version string) offRampReaderTH { } // Create the version-specific reader. - reader, err := factory.NewOffRampReader(log, factory.NewEvmVersionFinder(), ccipcalc.EvmAddrToGeneric(offRampAddress), bc, lp, nil, nil, true) + reader, err := factory.NewOffRampReader(ctx, log, factory.NewEvmVersionFinder(), ccipcalc.EvmAddrToGeneric(offRampAddress), bc, lp, nil, nil, true) require.NoError(t, err) addr, err := reader.Address(ctx) require.NoError(t, err) @@ -306,6 +307,7 @@ func TestNewOffRampReader(t *testing.T) { } for _, tc := range tt { t.Run(tc.typeAndVersion, func(t *testing.T) { + ctx := tests.Context(t) b, err := utils.ABIEncode(`[{"type":"string"}]`, tc.typeAndVersion) require.NoError(t, err) c := evmclientmocks.NewClient(t) @@ -313,7 +315,7 @@ func TestNewOffRampReader(t *testing.T) { addr := ccipcalc.EvmAddrToGeneric(utils.RandomAddress()) lp := lpmocks.NewLogPoller(t) lp.On("RegisterFilter", mock.Anything, mock.Anything).Return(nil).Maybe() - _, err = factory.NewOffRampReader(logger.Test(t), factory.NewEvmVersionFinder(), addr, c, lp, nil, nil, true) + _, err = factory.NewOffRampReader(ctx, logger.Test(t), factory.NewEvmVersionFinder(), addr, c, lp, nil, nil, true) if tc.expectedErr != "" { assert.EqualError(t, err, tc.expectedErr) } else { diff --git a/core/services/ocr2/plugins/ccip/internal/ccipdata/onramp_reader_test.go b/core/services/ocr2/plugins/ccip/internal/ccipdata/onramp_reader_test.go index db2e54f96ba..6340eb21682 100644 --- a/core/services/ocr2/plugins/ccip/internal/ccipdata/onramp_reader_test.go +++ b/core/services/ocr2/plugins/ccip/internal/ccipdata/onramp_reader_test.go @@ -13,6 +13,7 @@ import ( "github.com/stretchr/testify/require" cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccip" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" "github.com/smartcontractkit/chainlink-common/pkg/logger" @@ -37,9 +38,10 @@ type onRampReaderTH struct { } func TestNewOnRampReader_noContractAtAddress(t *testing.T) { + ctx := tests.Context(t) _, bc := ccipdata.NewSimulation(t) addr := ccipcalc.EvmAddrToGeneric(utils.RandomAddress()) - _, err := factory.NewOnRampReader(logger.Test(t), factory.NewEvmVersionFinder(), testutils.SimulatedChainID.Uint64(), testutils.SimulatedChainID.Uint64(), addr, lpmocks.NewLogPoller(t), bc) + _, err := factory.NewOnRampReader(ctx, logger.Test(t), factory.NewEvmVersionFinder(), testutils.SimulatedChainID.Uint64(), testutils.SimulatedChainID.Uint64(), addr, lpmocks.NewLogPoller(t), bc) assert.EqualError(t, err, fmt.Sprintf("unable to read type and version: error calling typeAndVersion on addr: %s no contract code at given address", addr)) } @@ -67,6 +69,7 @@ func TestOnRampReaderInit(t *testing.T) { } func setupOnRampReaderTH(t *testing.T, version string) onRampReaderTH { + ctx := tests.Context(t) user, bc := ccipdata.NewSimulation(t) log := logger.Test(t) orm := logpoller.NewORM(testutils.SimulatedChainID, pgtest.NewSqlxDB(t), log) @@ -100,7 +103,7 @@ func setupOnRampReaderTH(t *testing.T, version string) onRampReaderTH { } // Create the version-specific reader. - reader, err := factory.NewOnRampReader(log, factory.NewEvmVersionFinder(), testutils.SimulatedChainID.Uint64(), testutils.SimulatedChainID.Uint64(), ccipcalc.EvmAddrToGeneric(onRampAddress), lp, bc) + reader, err := factory.NewOnRampReader(ctx, log, factory.NewEvmVersionFinder(), testutils.SimulatedChainID.Uint64(), testutils.SimulatedChainID.Uint64(), ccipcalc.EvmAddrToGeneric(onRampAddress), lp, bc) require.NoError(t, err) return onRampReaderTH{ @@ -309,6 +312,7 @@ func TestNewOnRampReader(t *testing.T) { } for _, tc := range tt { t.Run(tc.typeAndVersion, func(t *testing.T) { + ctx := tests.Context(t) b, err := utils.ABIEncode(`[{"type":"string"}]`, tc.typeAndVersion) require.NoError(t, err) c := evmclientmocks.NewClient(t) @@ -316,7 +320,7 @@ func TestNewOnRampReader(t *testing.T) { addr := ccipcalc.EvmAddrToGeneric(utils.RandomAddress()) lp := lpmocks.NewLogPoller(t) lp.On("RegisterFilter", mock.Anything, mock.Anything).Return(nil).Maybe() - _, err = factory.NewOnRampReader(logger.Test(t), factory.NewEvmVersionFinder(), 1, 2, addr, lp, c) + _, err = factory.NewOnRampReader(ctx, logger.Test(t), factory.NewEvmVersionFinder(), 1, 2, addr, lp, c) if tc.expectedErr != "" { require.EqualError(t, err, tc.expectedErr) } else { diff --git a/core/services/ocr2/plugins/ccip/internal/ccipdata/price_registry_reader_test.go b/core/services/ocr2/plugins/ccip/internal/ccipdata/price_registry_reader_test.go index 1f8d48ddfee..f5b97926b6e 100644 --- a/core/services/ocr2/plugins/ccip/internal/ccipdata/price_registry_reader_test.go +++ b/core/services/ocr2/plugins/ccip/internal/ccipdata/price_registry_reader_test.go @@ -2,6 +2,7 @@ package ccipdata_test import ( "context" + "math" "math/big" "reflect" "testing" @@ -137,7 +138,7 @@ func setupPriceRegistryReaderTH(t *testing.T) priceRegReaderTH { b2 := commitAndGetBlockTs(ec) // Capture all lp data. - lp.PollAndSaveLogs(context.Background(), 1) + lp.PollAndSaveLogs(ctx, 1) return priceRegReaderTH{ lp: lp, @@ -162,15 +163,16 @@ func setupPriceRegistryReaderTH(t *testing.T) priceRegReaderTH { } func testPriceRegistryReader(t *testing.T, th priceRegReaderTH, pr ccipdata.PriceRegistryReader) { + ctx := testutils.Context(t) // Assert have expected fee tokens. - gotFeeTokens, err := pr.GetFeeTokens(context.Background()) + gotFeeTokens, err := pr.GetFeeTokens(ctx) require.NoError(t, err) evmAddrs, err := ccipcalc.GenericAddrsToEvm(gotFeeTokens...) require.NoError(t, err) assert.Equal(t, th.expectedFeeTokens, evmAddrs) // Note unsupported chain selector simply returns an empty set not an error - gasUpdates, err := pr.GetGasPriceUpdatesCreatedAfter(context.Background(), 1e6, time.Unix(0, 0), 0) + gasUpdates, err := pr.GetGasPriceUpdatesCreatedAfter(ctx, 1e6, time.Unix(0, 0), 0) require.NoError(t, err) assert.Len(t, gasUpdates, 0) @@ -188,26 +190,30 @@ func testPriceRegistryReader(t *testing.T, th priceRegReaderTH, pr ccipdata.Pric } expectedToken = append(expectedToken, th.expectedTokenUpdates[th.blockTs[j]]...) } - gasUpdates, err = pr.GetAllGasPriceUpdatesCreatedAfter(context.Background(), time.Unix(int64(ts-1), 0), 0) + if ts > math.MaxInt64 { + t.Fatalf("timestamp overflows int64: %d", ts) + } + unixTS := time.Unix(int64(ts-1), 0) //nolint:gosec // G115 false positive + gasUpdates, err = pr.GetAllGasPriceUpdatesCreatedAfter(ctx, unixTS, 0) require.NoError(t, err) assert.Len(t, gasUpdates, len(expectedGas)) - gasUpdates, err = pr.GetGasPriceUpdatesCreatedAfter(context.Background(), th.destSelectors[0], time.Unix(int64(ts-1), 0), 0) + gasUpdates, err = pr.GetGasPriceUpdatesCreatedAfter(ctx, th.destSelectors[0], unixTS, 0) require.NoError(t, err) assert.Len(t, gasUpdates, len(expectedDest0Gas)) - tokenUpdates, err2 := pr.GetTokenPriceUpdatesCreatedAfter(context.Background(), time.Unix(int64(ts-1), 0), 0) + tokenUpdates, err2 := pr.GetTokenPriceUpdatesCreatedAfter(ctx, unixTS, 0) require.NoError(t, err2) assert.Len(t, tokenUpdates, len(expectedToken)) } // Empty token set should return empty set no error. - gotEmpty, err := pr.GetTokenPrices(context.Background(), []cciptypes.Address{}) + gotEmpty, err := pr.GetTokenPrices(ctx, []cciptypes.Address{}) require.NoError(t, err) assert.Len(t, gotEmpty, 0) // We expect latest token prices to apply - allTokenUpdates, err := pr.GetTokenPriceUpdatesCreatedAfter(context.Background(), time.Unix(0, 0), 0) + allTokenUpdates, err := pr.GetTokenPriceUpdatesCreatedAfter(ctx, time.Unix(0, 0), 0) require.NoError(t, err) // Build latest map latest := make(map[cciptypes.Address]*big.Int) @@ -222,7 +228,7 @@ func testPriceRegistryReader(t *testing.T, th priceRegReaderTH, pr ccipdata.Pric latest[allTokenUpdates[i].Token] = allTokenUpdates[i].Value allTokens = append(allTokens, allTokenUpdates[i].Token) } - tokenPrices, err := pr.GetTokenPrices(context.Background(), allTokens) + tokenPrices, err := pr.GetTokenPrices(ctx, allTokens) require.NoError(t, err) require.Len(t, tokenPrices, len(allTokens)) for _, p := range tokenPrices { diff --git a/core/services/ocr2/plugins/ccip/internal/ccipdata/usdc_reader.go b/core/services/ocr2/plugins/ccip/internal/ccipdata/usdc_reader.go index cd8fd3150ae..792e2eb7253 100644 --- a/core/services/ocr2/plugins/ccip/internal/ccipdata/usdc_reader.go +++ b/core/services/ocr2/plugins/ccip/internal/ccipdata/usdc_reader.go @@ -52,13 +52,11 @@ type USDCReaderImpl struct { } func (u *USDCReaderImpl) Close() error { - // FIXME Dim pgOpts removed from LogPoller return u.lp.UnregisterFilter(context.Background(), u.filter.Name) } -func (u *USDCReaderImpl) RegisterFilters() error { - // FIXME Dim pgOpts removed from LogPoller - return u.lp.RegisterFilter(context.Background(), u.filter) +func (u *USDCReaderImpl) RegisterFilters(ctx context.Context) error { + return u.lp.RegisterFilter(ctx, u.filter) } // usdcPayload has to match the onchain event emitted by the USDC message transmitter @@ -136,7 +134,7 @@ func (u *USDCReaderImpl) GetUSDCMessagePriorToLogIndexInTx(ctx context.Context, return parseUSDCMessageSent(allUsdcTokensData[usdcTokenIndex]) } -func NewUSDCReader(lggr logger.Logger, jobID string, transmitter common.Address, lp logpoller.LogPoller, registerFilters bool) (*USDCReaderImpl, error) { +func NewUSDCReader(ctx context.Context, lggr logger.Logger, jobID string, transmitter common.Address, lp logpoller.LogPoller, registerFilters bool) (*USDCReaderImpl, error) { eventSig := utils.Keccak256Fixed([]byte("MessageSent(bytes)")) r := &USDCReaderImpl{ @@ -154,15 +152,15 @@ func NewUSDCReader(lggr logger.Logger, jobID string, transmitter common.Address, } if registerFilters { - if err := r.RegisterFilters(); err != nil { + if err := r.RegisterFilters(ctx); err != nil { return nil, fmt.Errorf("register filters: %w", err) } } return r, nil } -func CloseUSDCReader(lggr logger.Logger, jobID string, transmitter common.Address, lp logpoller.LogPoller) error { - r, err := NewUSDCReader(lggr, jobID, transmitter, lp, false) +func CloseUSDCReader(ctx context.Context, lggr logger.Logger, jobID string, transmitter common.Address, lp logpoller.LogPoller) error { + r, err := NewUSDCReader(ctx, lggr, jobID, transmitter, lp, false) if err != nil { return err } diff --git a/core/services/ocr2/plugins/ccip/internal/ccipdata/usdc_reader_internal_test.go b/core/services/ocr2/plugins/ccip/internal/ccipdata/usdc_reader_internal_test.go index 953da52713b..d3df9e2124a 100644 --- a/core/services/ocr2/plugins/ccip/internal/ccipdata/usdc_reader_internal_test.go +++ b/core/services/ocr2/plugins/ccip/internal/ccipdata/usdc_reader_internal_test.go @@ -1,7 +1,6 @@ package ccipdata import ( - "context" "fmt" "testing" "time" @@ -15,6 +14,7 @@ import ( "github.com/stretchr/testify/require" "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/headtracker" @@ -35,8 +35,9 @@ func TestLogPollerClient_GetUSDCMessagePriorToLogIndexInTx(t *testing.T) { lggr := logger.Test(t) t.Run("multiple found - selected last", func(t *testing.T) { + ctx := tests.Context(t) lp := lpmocks.NewLogPoller(t) - u, _ := NewUSDCReader(lggr, "job_123", utils.RandomAddress(), lp, false) + u, _ := NewUSDCReader(ctx, lggr, "job_123", utils.RandomAddress(), lp, false) lp.On("IndexedLogsByTxHash", mock.Anything, @@ -49,15 +50,16 @@ func TestLogPollerClient_GetUSDCMessagePriorToLogIndexInTx(t *testing.T) { {LogIndex: ccipLogIndex, Data: []byte("0")}, {LogIndex: ccipLogIndex + 1, Data: []byte("1")}, }, nil) - usdcMessageData, err := u.GetUSDCMessagePriorToLogIndexInTx(context.Background(), ccipLogIndex, 0, txHash.String()) + usdcMessageData, err := u.GetUSDCMessagePriorToLogIndexInTx(ctx, ccipLogIndex, 0, txHash.String()) assert.NoError(t, err) assert.Equal(t, expectedPostParse, hexutil.Encode(usdcMessageData)) lp.AssertExpectations(t) }) t.Run("multiple found - selected first", func(t *testing.T) { + ctx := tests.Context(t) lp := lpmocks.NewLogPoller(t) - u, _ := NewUSDCReader(lggr, "job_123", utils.RandomAddress(), lp, false) + u, _ := NewUSDCReader(ctx, lggr, "job_123", utils.RandomAddress(), lp, false) lp.On("IndexedLogsByTxHash", mock.Anything, @@ -70,15 +72,16 @@ func TestLogPollerClient_GetUSDCMessagePriorToLogIndexInTx(t *testing.T) { {LogIndex: ccipLogIndex, Data: []byte("0")}, {LogIndex: ccipLogIndex + 1, Data: []byte("1")}, }, nil) - usdcMessageData, err := u.GetUSDCMessagePriorToLogIndexInTx(context.Background(), ccipLogIndex, 1, txHash.String()) + usdcMessageData, err := u.GetUSDCMessagePriorToLogIndexInTx(ctx, ccipLogIndex, 1, txHash.String()) assert.NoError(t, err) assert.Equal(t, expectedPostParse, hexutil.Encode(usdcMessageData)) lp.AssertExpectations(t) }) t.Run("logs fetched from memory in subsequent calls", func(t *testing.T) { + ctx := tests.Context(t) lp := lpmocks.NewLogPoller(t) - u, _ := NewUSDCReader(lggr, "job_123", utils.RandomAddress(), lp, false) + u, _ := NewUSDCReader(ctx, lggr, "job_123", utils.RandomAddress(), lp, false) lp.On("IndexedLogsByTxHash", mock.Anything, @@ -93,12 +96,12 @@ func TestLogPollerClient_GetUSDCMessagePriorToLogIndexInTx(t *testing.T) { }, nil).Once() // first call logs must be fetched from lp - usdcMessageData, err := u.GetUSDCMessagePriorToLogIndexInTx(context.Background(), ccipLogIndex, 1, txHash.String()) + usdcMessageData, err := u.GetUSDCMessagePriorToLogIndexInTx(ctx, ccipLogIndex, 1, txHash.String()) assert.NoError(t, err) assert.Equal(t, expectedPostParse, hexutil.Encode(usdcMessageData)) // subsequent call, logs must be fetched from memory - usdcMessageData, err = u.GetUSDCMessagePriorToLogIndexInTx(context.Background(), ccipLogIndex, 1, txHash.String()) + usdcMessageData, err = u.GetUSDCMessagePriorToLogIndexInTx(ctx, ccipLogIndex, 1, txHash.String()) assert.NoError(t, err) assert.Equal(t, expectedPostParse, hexutil.Encode(usdcMessageData)) @@ -106,8 +109,9 @@ func TestLogPollerClient_GetUSDCMessagePriorToLogIndexInTx(t *testing.T) { }) t.Run("none found", func(t *testing.T) { + ctx := tests.Context(t) lp := lpmocks.NewLogPoller(t) - u, _ := NewUSDCReader(lggr, "job_123", utils.RandomAddress(), lp, false) + u, _ := NewUSDCReader(ctx, lggr, "job_123", utils.RandomAddress(), lp, false) lp.On("IndexedLogsByTxHash", mock.Anything, u.usdcMessageSent, @@ -115,7 +119,7 @@ func TestLogPollerClient_GetUSDCMessagePriorToLogIndexInTx(t *testing.T) { txHash, ).Return([]logpoller.Log{}, nil) - usdcMessageData, err := u.GetUSDCMessagePriorToLogIndexInTx(context.Background(), ccipLogIndex, 0, txHash.String()) + usdcMessageData, err := u.GetUSDCMessagePriorToLogIndexInTx(ctx, ccipLogIndex, 0, txHash.String()) assert.Errorf(t, err, fmt.Sprintf("no USDC message found prior to log index %d in tx %s", ccipLogIndex, txHash.Hex())) assert.Nil(t, usdcMessageData) @@ -137,6 +141,7 @@ func TestParse(t *testing.T) { func TestFilters(t *testing.T) { t.Run("filters of different jobs should be distinct", func(t *testing.T) { + ctx := tests.Context(t) lggr := logger.Test(t) chainID := testutils.NewRandomEVMChainID() db := pgtest.NewSqlxDB(t) @@ -163,15 +168,15 @@ func TestFilters(t *testing.T) { f1 := logpoller.FilterName("USDC message sent", jobID1, transmitter.Hex()) f2 := logpoller.FilterName("USDC message sent", jobID2, transmitter.Hex()) - _, err := NewUSDCReader(lggr, jobID1, transmitter, lp, true) + _, err := NewUSDCReader(ctx, lggr, jobID1, transmitter, lp, true) assert.NoError(t, err) assert.True(t, lp.HasFilter(f1)) - _, err = NewUSDCReader(lggr, jobID2, transmitter, lp, true) + _, err = NewUSDCReader(ctx, lggr, jobID2, transmitter, lp, true) assert.NoError(t, err) assert.True(t, lp.HasFilter(f2)) - err = CloseUSDCReader(lggr, jobID2, transmitter, lp) + err = CloseUSDCReader(ctx, lggr, jobID2, transmitter, lp) assert.NoError(t, err) assert.True(t, lp.HasFilter(f1)) assert.False(t, lp.HasFilter(f2)) diff --git a/core/services/ocr2/plugins/ccip/internal/ccipdata/v1_2_0/commit_store.go b/core/services/ocr2/plugins/ccip/internal/ccipdata/v1_2_0/commit_store.go index 29076e6cd74..2d772e3bd0a 100644 --- a/core/services/ocr2/plugins/ccip/internal/ccipdata/v1_2_0/commit_store.go +++ b/core/services/ocr2/plugins/ccip/internal/ccipdata/v1_2_0/commit_store.go @@ -277,7 +277,7 @@ func (c *CommitStore) ChangeConfig(_ context.Context, onchainConfig []byte, offc } func (c *CommitStore) Close() error { - return logpollerutil.UnregisterLpFilters(c.lp, c.filters) + return logpollerutil.UnregisterLpFilters(context.Background(), c.lp, c.filters) } func (c *CommitStore) parseReport(log types.Log) (*cciptypes.CommitStoreReport, error) { @@ -429,8 +429,8 @@ func (c *CommitStore) VerifyExecutionReport(ctx context.Context, report cciptype return true, nil } -func (c *CommitStore) RegisterFilters() error { - return logpollerutil.RegisterLpFilters(c.lp, c.filters) +func (c *CommitStore) RegisterFilters(ctx context.Context) error { + return logpollerutil.RegisterLpFilters(ctx, c.lp, c.filters) } func NewCommitStore(lggr logger.Logger, addr common.Address, ec client.Client, lp logpoller.LogPoller) (*CommitStore, error) { diff --git a/core/services/ocr2/plugins/ccip/internal/ccipdata/v1_2_0/offramp.go b/core/services/ocr2/plugins/ccip/internal/ccipdata/v1_2_0/offramp.go index f2887688965..e8017016690 100644 --- a/core/services/ocr2/plugins/ccip/internal/ccipdata/v1_2_0/offramp.go +++ b/core/services/ocr2/plugins/ccip/internal/ccipdata/v1_2_0/offramp.go @@ -428,10 +428,10 @@ func (o *OffRamp) ChangeConfig(ctx context.Context, onchainConfigBytes []byte, o } func (o *OffRamp) Close() error { - return logpollerutil.UnregisterLpFilters(o.lp, o.filters) + return logpollerutil.UnregisterLpFilters(context.Background(), o.lp, o.filters) } -func (o *OffRamp) RegisterFilters() error { - return logpollerutil.RegisterLpFilters(o.lp, o.filters) +func (o *OffRamp) RegisterFilters(ctx context.Context) error { + return logpollerutil.RegisterLpFilters(ctx, o.lp, o.filters) } func (o *OffRamp) GetExecutionState(ctx context.Context, sequenceNumber uint64) (uint8, error) { diff --git a/core/services/ocr2/plugins/ccip/internal/ccipdata/v1_2_0/onramp.go b/core/services/ocr2/plugins/ccip/internal/ccipdata/v1_2_0/onramp.go index 071e8a8e03e..52f241a30a6 100644 --- a/core/services/ocr2/plugins/ccip/internal/ccipdata/v1_2_0/onramp.go +++ b/core/services/ocr2/plugins/ccip/internal/ccipdata/v1_2_0/onramp.go @@ -213,11 +213,11 @@ func (o *OnRamp) IsSourceCursed(ctx context.Context) (bool, error) { } func (o *OnRamp) Close() error { - return logpollerutil.UnregisterLpFilters(o.lp, o.filters) + return logpollerutil.UnregisterLpFilters(context.Background(), o.lp, o.filters) } -func (o *OnRamp) RegisterFilters() error { - return logpollerutil.RegisterLpFilters(o.lp, o.filters) +func (o *OnRamp) RegisterFilters(ctx context.Context) error { + return logpollerutil.RegisterLpFilters(ctx, o.lp, o.filters) } func (o *OnRamp) logToMessage(log types.Log) (*cciptypes.EVM2EVMMessage, error) { diff --git a/core/services/ocr2/plugins/ccip/internal/ccipdata/v1_2_0/price_registry.go b/core/services/ocr2/plugins/ccip/internal/ccipdata/v1_2_0/price_registry.go index 4c4058922dc..636b37c9100 100644 --- a/core/services/ocr2/plugins/ccip/internal/ccipdata/v1_2_0/price_registry.go +++ b/core/services/ocr2/plugins/ccip/internal/ccipdata/v1_2_0/price_registry.go @@ -50,7 +50,7 @@ type PriceRegistry struct { tokenDecimalsCache sync.Map } -func NewPriceRegistry(lggr logger.Logger, priceRegistryAddr common.Address, lp logpoller.LogPoller, ec client.Client, registerFilters bool) (*PriceRegistry, error) { +func NewPriceRegistry(ctx context.Context, lggr logger.Logger, priceRegistryAddr common.Address, lp logpoller.LogPoller, ec client.Client, registerFilters bool) (*PriceRegistry, error) { priceRegistry, err := price_registry_1_2_0.NewPriceRegistry(priceRegistryAddr, ec) if err != nil { return nil, err @@ -79,7 +79,7 @@ func NewPriceRegistry(lggr logger.Logger, priceRegistryAddr common.Address, lp l Retention: ccipdata.CacheEvictionLogsRetention, }} if registerFilters { - err = logpollerutil.RegisterLpFilters(lp, filters) + err = logpollerutil.RegisterLpFilters(ctx, lp, filters) if err != nil { return nil, err } @@ -151,7 +151,7 @@ func (p *PriceRegistry) GetFeeTokens(ctx context.Context) ([]cciptypes.Address, } func (p *PriceRegistry) Close() error { - return logpollerutil.UnregisterLpFilters(p.lp, p.filters) + return logpollerutil.UnregisterLpFilters(context.Background(), p.lp, p.filters) } func (p *PriceRegistry) GetTokenPriceUpdatesCreatedAfter(ctx context.Context, ts time.Time, confs int) ([]cciptypes.TokenPriceUpdateWithTxMeta, error) { diff --git a/core/services/ocr2/plugins/ccip/internal/ccipdata/v1_5_0/onramp.go b/core/services/ocr2/plugins/ccip/internal/ccipdata/v1_5_0/onramp.go index ad540ffd648..da41d116bc8 100644 --- a/core/services/ocr2/plugins/ccip/internal/ccipdata/v1_5_0/onramp.go +++ b/core/services/ocr2/plugins/ccip/internal/ccipdata/v1_5_0/onramp.go @@ -216,11 +216,11 @@ func (o *OnRamp) IsSourceCursed(ctx context.Context) (bool, error) { } func (o *OnRamp) Close() error { - return logpollerutil.UnregisterLpFilters(o.lp, o.filters) + return logpollerutil.UnregisterLpFilters(context.Background(), o.lp, o.filters) } -func (o *OnRamp) RegisterFilters() error { - return logpollerutil.RegisterLpFilters(o.lp, o.filters) +func (o *OnRamp) RegisterFilters(ctx context.Context) error { + return logpollerutil.RegisterLpFilters(ctx, o.lp, o.filters) } func (o *OnRamp) logToMessage(log types.Log) (*cciptypes.EVM2EVMMessage, error) { diff --git a/core/services/ocr2/plugins/ccip/internal/ccipdb/price_service.go b/core/services/ocr2/plugins/ccip/internal/ccipdb/price_service.go index e8b9a4de721..b5e8853d67c 100644 --- a/core/services/ocr2/plugins/ccip/internal/ccipdb/price_service.go +++ b/core/services/ocr2/plugins/ccip/internal/ccipdb/price_service.go @@ -68,10 +68,9 @@ type priceService struct { destPriceRegistryReader ccipdata.PriceRegistryReader services.StateMachine - wg *sync.WaitGroup - backgroundCtx context.Context //nolint:containedctx - backgroundCancel context.CancelFunc - dynamicConfigMu *sync.RWMutex + wg sync.WaitGroup + stopChan services.StopChan + dynamicConfigMu sync.RWMutex } func NewPriceService( @@ -85,8 +84,6 @@ func NewPriceService( priceGetter pricegetter.AllTokensPriceGetter, offRampReader ccipdata.OffRampReader, ) PriceService { - ctx, cancel := context.WithCancel(context.Background()) - pw := &priceService{ gasUpdateInterval: gasPriceUpdateInterval, tokenUpdateInterval: tokenPriceUpdateInterval, @@ -100,11 +97,7 @@ func NewPriceService( sourceNative: sourceNative, priceGetter: priceGetter, offRampReader: offRampReader, - - wg: new(sync.WaitGroup), - backgroundCtx: ctx, - backgroundCancel: cancel, - dynamicConfigMu: &sync.RWMutex{}, + stopChan: make(services.StopChan), } return pw } @@ -121,13 +114,16 @@ func (p *priceService) Start(context.Context) error { func (p *priceService) Close() error { return p.StateMachine.StopOnce("PriceService", func() error { p.lggr.Info("Closing PriceService") - p.backgroundCancel() + close(p.stopChan) p.wg.Wait() return nil }) } func (p *priceService) run() { + ctx, cancel := p.stopChan.NewCtx() + defer cancel() + gasUpdateTicker := time.NewTicker(utils.WithJitter(p.gasUpdateInterval)) tokenUpdateTicker := time.NewTicker(utils.WithJitter(p.tokenUpdateInterval)) @@ -138,15 +134,15 @@ func (p *priceService) run() { for { select { - case <-p.backgroundCtx.Done(): + case <-ctx.Done(): return case <-gasUpdateTicker.C: - err := p.runGasPriceUpdate(p.backgroundCtx) + err := p.runGasPriceUpdate(ctx) if err != nil { p.lggr.Errorw("Error when updating gas prices in the background", "err", err) } case <-tokenUpdateTicker.C: - err := p.runTokenPriceUpdate(p.backgroundCtx) + err := p.runTokenPriceUpdate(ctx) if err != nil { p.lggr.Errorw("Error when updating token prices in the background", "err", err) } diff --git a/core/services/ocr2/plugins/ccip/internal/logpollerutil/filters.go b/core/services/ocr2/plugins/ccip/internal/logpollerutil/filters.go index e42dd8c154d..de185611641 100644 --- a/core/services/ocr2/plugins/ccip/internal/logpollerutil/filters.go +++ b/core/services/ocr2/plugins/ccip/internal/logpollerutil/filters.go @@ -9,26 +9,24 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils" ) -func RegisterLpFilters(lp logpoller.LogPoller, filters []logpoller.Filter) error { +func RegisterLpFilters(ctx context.Context, lp logpoller.LogPoller, filters []logpoller.Filter) error { for _, lpFilter := range filters { if filterContainsZeroAddress(lpFilter.Addresses) { continue } - // FIXME Dim pgOpts removed from LogPoller - if err := lp.RegisterFilter(context.Background(), lpFilter); err != nil { + if err := lp.RegisterFilter(ctx, lpFilter); err != nil { return err } } return nil } -func UnregisterLpFilters(lp logpoller.LogPoller, filters []logpoller.Filter) error { +func UnregisterLpFilters(ctx context.Context, lp logpoller.LogPoller, filters []logpoller.Filter) error { for _, lpFilter := range filters { if filterContainsZeroAddress(lpFilter.Addresses) { continue } - // FIXME Dim pgOpts removed from LogPoller - if err := lp.UnregisterFilter(context.Background(), lpFilter.Name); err != nil { + if err := lp.UnregisterFilter(ctx, lpFilter.Name); err != nil { return err } } diff --git a/core/services/ocr2/plugins/ccip/internal/oraclelib/backfilled_oracle.go b/core/services/ocr2/plugins/ccip/internal/oraclelib/backfilled_oracle.go index d2851e3a079..053cddabcd9 100644 --- a/core/services/ocr2/plugins/ccip/internal/oraclelib/backfilled_oracle.go +++ b/core/services/ocr2/plugins/ccip/internal/oraclelib/backfilled_oracle.go @@ -6,128 +6,23 @@ import ( "sync/atomic" "time" + commonservices "github.com/smartcontractkit/chainlink-common/pkg/services" "github.com/smartcontractkit/chainlink/v2/core/services" "go.uber.org/multierr" - "github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/job" ) -type BackfilledOracle struct { - srcStartBlock, dstStartBlock uint64 - oracleStarted atomic.Bool - cancelFn context.CancelFunc - src, dst logpoller.LogPoller - oracle job.ServiceCtx - lggr logger.Logger -} - -func NewBackfilledOracle(lggr logger.Logger, src, dst logpoller.LogPoller, srcStartBlock, dstStartBlock uint64, oracle job.ServiceCtx) *BackfilledOracle { - return &BackfilledOracle{ - srcStartBlock: srcStartBlock, - dstStartBlock: dstStartBlock, - oracleStarted: atomic.Bool{}, - cancelFn: nil, - src: src, - dst: dst, - oracle: oracle, - lggr: lggr, - } -} - -func (r *BackfilledOracle) Start(_ context.Context) error { - go r.Run() - return nil -} - -func (r *BackfilledOracle) IsRunning() bool { - return r.oracleStarted.Load() -} - -func (r *BackfilledOracle) Run() { - ctx, cancelFn := context.WithCancel(context.Background()) - r.cancelFn = cancelFn - var err error - var errMu sync.Mutex - var wg sync.WaitGroup - // Replay in parallel if both requested. - if r.srcStartBlock != 0 { - wg.Add(1) - go func() { - defer wg.Done() - s := time.Now() - r.lggr.Infow("start replaying src chain", "fromBlock", r.srcStartBlock) - srcReplayErr := r.src.Replay(ctx, int64(r.srcStartBlock)) - errMu.Lock() - err = multierr.Combine(err, srcReplayErr) - errMu.Unlock() - r.lggr.Infow("finished replaying src chain", "time", time.Since(s)) - }() - } - if r.dstStartBlock != 0 { - wg.Add(1) - go func() { - defer wg.Done() - s := time.Now() - r.lggr.Infow("start replaying dst chain", "fromBlock", r.dstStartBlock) - dstReplayErr := r.dst.Replay(ctx, int64(r.dstStartBlock)) - errMu.Lock() - err = multierr.Combine(err, dstReplayErr) - errMu.Unlock() - r.lggr.Infow("finished replaying dst chain", "time", time.Since(s)) - }() - } - wg.Wait() - if err != nil { - r.lggr.Criticalw("unexpected error replaying, continuing plugin boot without all the logs backfilled", "err", err) - } - if err := ctx.Err(); err != nil { - r.lggr.Errorw("context already cancelled", "err", err) - return - } - // Start oracle with all logs present from dstStartBlock on dst and - // all logs from srcStartBlock on src. - if err := r.oracle.Start(ctx); err != nil { - // Should never happen. - r.lggr.Errorw("unexpected error starting oracle", "err", err) - } else { - r.oracleStarted.Store(true) - } -} - -func (r *BackfilledOracle) Close() error { - if r.oracleStarted.Load() { - // If the oracle is running, it must be Closed/stopped - if err := r.oracle.Close(); err != nil { - r.lggr.Errorw("unexpected error stopping oracle", "err", err) - return err - } - // Flag the oracle as closed with our internal variable that keeps track - // of its state. This will allow to re-start the process - r.oracleStarted.Store(false) - } - if r.cancelFn != nil { - // This is useful to step the previous tasks that are spawned in - // parallel before starting the Oracle. This will use the context to - // signal them to exit immediately. - // - // It can be possible this is the only way to stop the Start() async - // flow, specially when the previusly task are running (the replays) and - // `oracleStarted` would be false in that example. Calling `cancelFn()` - // will stop the replays and will prevent the oracle to start - r.cancelFn() - } - return nil -} - func NewChainAgnosticBackFilledOracle(lggr logger.Logger, srcProvider services.ServiceCtx, dstProvider services.ServiceCtx, oracle job.ServiceCtx) *ChainAgnosticBackFilledOracle { return &ChainAgnosticBackFilledOracle{ srcProvider: srcProvider, dstProvider: dstProvider, oracle: oracle, lggr: lggr, + stopCh: make(chan struct{}), + done: make(chan struct{}), } } @@ -137,7 +32,8 @@ type ChainAgnosticBackFilledOracle struct { oracle job.ServiceCtx lggr logger.Logger oracleStarted atomic.Bool - cancelFn context.CancelFunc + stopCh commonservices.StopChan + done chan struct{} } func (r *ChainAgnosticBackFilledOracle) Start(_ context.Context) error { @@ -146,8 +42,10 @@ func (r *ChainAgnosticBackFilledOracle) Start(_ context.Context) error { } func (r *ChainAgnosticBackFilledOracle) run() { - ctx, cancelFn := context.WithCancel(context.Background()) - r.cancelFn = cancelFn + defer close(r.done) + ctx, cancel := r.stopCh.NewCtx() + defer cancel() + var err error var errMu sync.Mutex var wg sync.WaitGroup @@ -192,6 +90,8 @@ func (r *ChainAgnosticBackFilledOracle) run() { } func (r *ChainAgnosticBackFilledOracle) Close() error { + close(r.stopCh) + <-r.done if r.oracleStarted.Load() { // If the oracle is running, it must be Closed/stopped // TODO: Close should be safe to call in either case? @@ -203,16 +103,5 @@ func (r *ChainAgnosticBackFilledOracle) Close() error { // of its state. This will allow to re-start the process r.oracleStarted.Store(false) } - if r.cancelFn != nil { - // This is useful to step the previous tasks that are spawned in - // parallel before starting the Oracle. This will use the context to - // signal them to exit immediately. - // - // It can be possible this is the only way to stop the Start() async - // flow, specially when the previusly task are running (the replays) and - // `oracleStarted` would be false in that example. Calling `cancelFn()` - // will stop the replays and will prevent the oracle to start - r.cancelFn() - } return nil } diff --git a/core/services/ocr2/plugins/ccip/internal/oraclelib/backfilled_oracle_test.go b/core/services/ocr2/plugins/ccip/internal/oraclelib/backfilled_oracle_test.go deleted file mode 100644 index 6db1ebbadd9..00000000000 --- a/core/services/ocr2/plugins/ccip/internal/oraclelib/backfilled_oracle_test.go +++ /dev/null @@ -1,56 +0,0 @@ -package oraclelib - -import ( - "testing" - - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - - lpmocks "github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller/mocks" - "github.com/smartcontractkit/chainlink/v2/core/logger" - jobmocks "github.com/smartcontractkit/chainlink/v2/core/services/job/mocks" -) - -func TestBackfilledOracle(t *testing.T) { - // First scenario: Start() fails, check that all Replay are being called. - lp1 := lpmocks.NewLogPoller(t) - lp2 := lpmocks.NewLogPoller(t) - lp1.On("Replay", mock.Anything, int64(1)).Return(nil) - lp2.On("Replay", mock.Anything, int64(2)).Return(nil) - oracle1 := jobmocks.NewServiceCtx(t) - oracle1.On("Start", mock.Anything).Return(errors.New("Failed to start")).Twice() - job := NewBackfilledOracle(logger.TestLogger(t), lp1, lp2, 1, 2, oracle1) - - job.Run() - assert.False(t, job.IsRunning()) - job.Run() - assert.False(t, job.IsRunning()) - - /// Start -> Stop -> Start - oracle2 := jobmocks.NewServiceCtx(t) - oracle2.On("Start", mock.Anything).Return(nil).Twice() - oracle2.On("Close").Return(nil).Once() - - job2 := NewBackfilledOracle(logger.TestLogger(t), lp1, lp2, 1, 2, oracle2) - job2.Run() - assert.True(t, job2.IsRunning()) - assert.Nil(t, job2.Close()) - assert.False(t, job2.IsRunning()) - assert.Nil(t, job2.Close()) - assert.False(t, job2.IsRunning()) - job2.Run() - assert.True(t, job2.IsRunning()) - - /// Replay fails, but it starts anyway - lp11 := lpmocks.NewLogPoller(t) - lp12 := lpmocks.NewLogPoller(t) - lp11.On("Replay", mock.Anything, int64(1)).Return(errors.New("Replay failed")).Once() - lp12.On("Replay", mock.Anything, int64(2)).Return(errors.New("Replay failed")).Once() - - oracle := jobmocks.NewServiceCtx(t) - oracle.On("Start", mock.Anything).Return(nil).Once() - job3 := NewBackfilledOracle(logger.NullLogger, lp11, lp12, 1, 2, oracle) - job3.Run() - assert.True(t, job3.IsRunning()) -} diff --git a/core/services/ocr2/plugins/ccip/internal/pricegetter/pipeline_test.go b/core/services/ocr2/plugins/ccip/internal/pricegetter/pipeline_test.go index 8aeeff96b57..e71d5402503 100644 --- a/core/services/ocr2/plugins/ccip/internal/pricegetter/pipeline_test.go +++ b/core/services/ocr2/plugins/ccip/internal/pricegetter/pipeline_test.go @@ -1,7 +1,6 @@ package pricegetter_test import ( - "context" "fmt" "math/big" "net/http" @@ -17,6 +16,7 @@ import ( config2 "github.com/smartcontractkit/chainlink-common/pkg/config" cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccip" + "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/bridges" "github.com/smartcontractkit/chainlink/v2/core/logger" @@ -31,6 +31,7 @@ import ( ) func TestDataSource(t *testing.T) { + ctx := testutils.Context(t) linkEth := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { _, err := w.Write([]byte(`{"JuelsPerETH": "200000000000000000000"}`)) require.NoError(t, err) @@ -58,7 +59,7 @@ func TestDataSource(t *testing.T) { priceGetter := newTestPipelineGetter(t, source) // Ask for all prices present in spec. - prices, err := priceGetter.GetJobSpecTokenPricesUSD(context.Background()) + prices, err := priceGetter.GetJobSpecTokenPricesUSD(ctx) require.NoError(t, err) assert.Equal(t, prices, map[cciptypes.Address]*big.Int{ linkTokenAddress: big.NewInt(0).Mul(big.NewInt(200), big.NewInt(1000000000000000000)), @@ -66,7 +67,7 @@ func TestDataSource(t *testing.T) { }) // Specifically ask for all prices - pricesWithInput, errWithInput := priceGetter.TokenPricesUSD(context.Background(), []cciptypes.Address{ + pricesWithInput, errWithInput := priceGetter.TokenPricesUSD(ctx, []cciptypes.Address{ linkTokenAddress, usdcTokenAddress, }) @@ -77,13 +78,13 @@ func TestDataSource(t *testing.T) { }) // Ask a non-existent price. - _, err = priceGetter.TokenPricesUSD(context.Background(), []cciptypes.Address{ + _, err = priceGetter.TokenPricesUSD(ctx, []cciptypes.Address{ ccipcalc.HexToAddress("0x1591690b8638f5fb2dbec82ac741805ac5da8b45dc5263f4875b0496fdce4e11"), }) require.Error(t, err) // Ask only one price - prices, err = priceGetter.TokenPricesUSD(context.Background(), []cciptypes.Address{linkTokenAddress}) + prices, err = priceGetter.TokenPricesUSD(ctx, []cciptypes.Address{linkTokenAddress}) require.NoError(t, err) assert.Equal(t, prices, map[cciptypes.Address]*big.Int{ linkTokenAddress: big.NewInt(0).Mul(big.NewInt(200), big.NewInt(1000000000000000000)), @@ -135,6 +136,7 @@ func TestParsingDifferentFormats(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + ctx := testutils.Context(t) token := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { _, err := fmt.Fprintf(w, `{"MyCoin": %s}`, tt.inputValue) require.NoError(t, err) @@ -151,7 +153,7 @@ func TestParsingDifferentFormats(t *testing.T) { `, token.URL, strings.ToLower(address.String())) prices, err := newTestPipelineGetter(t, source). - TokenPricesUSD(context.Background(), []cciptypes.Address{ccipcalc.EvmAddrToGeneric(address)}) + TokenPricesUSD(ctx, []cciptypes.Address{ccipcalc.EvmAddrToGeneric(address)}) if tt.expectedError { require.Error(t, err) diff --git a/core/services/ocr2/plugins/ccip/testhelpers/ccip_contracts.go b/core/services/ocr2/plugins/ccip/testhelpers/ccip_contracts.go index 8410e6ff938..d69be750253 100644 --- a/core/services/ocr2/plugins/ccip/testhelpers/ccip_contracts.go +++ b/core/services/ocr2/plugins/ccip/testhelpers/ccip_contracts.go @@ -22,6 +22,7 @@ import ( ocr2types "github.com/smartcontractkit/libocr/offchainreporting2plus/types" "github.com/smartcontractkit/chainlink-common/pkg/config" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" "github.com/smartcontractkit/chainlink-common/pkg/hashutil" "github.com/smartcontractkit/chainlink-common/pkg/merklemulti" @@ -438,7 +439,7 @@ func (c *CCIPContracts) SetNopsOnRamp(t *testing.T, nopsAndWeights []evm_2_evm_o tx, err := c.Source.OnRamp.SetNops(c.Source.User, nopsAndWeights) require.NoError(t, err) c.Source.Chain.Commit() - _, err = bind.WaitMined(context.Background(), c.Source.Chain, tx) + _, err = bind.WaitMined(tests.Context(t), c.Source.Chain, tx) require.NoError(t, err) } @@ -578,7 +579,7 @@ func (c *CCIPContracts) SetupExecOCR2Config(t *testing.T, execOnchainConfig, exe func (c *CCIPContracts) SetupOnchainConfig(t *testing.T, commitOnchainConfig, commitOffchainConfig, execOnchainConfig, execOffchainConfig []byte) int64 { // Note We do NOT set the payees, payment is done in the OCR2Base implementation - blockBeforeConfig, err := c.Dest.Chain.BlockByNumber(context.Background(), nil) + blockBeforeConfig, err := c.Dest.Chain.BlockByNumber(tests.Context(t), nil) require.NoError(t, err) c.SetupCommitOCR2Config(t, commitOnchainConfig, commitOffchainConfig) @@ -1292,8 +1293,8 @@ type ManualExecArgs struct { // if the block located has a timestamp greater than the timestamp of mentioned source block // it just returns the first block found with lesser timestamp of the source block // providing a value of args.DestDeployedAt ensures better performance by reducing the range of block numbers to be traversed -func (args *ManualExecArgs) ApproxDestStartBlock() error { - sourceBlockHdr, err := args.SourceChain.HeaderByNumber(context.Background(), args.SourceStartBlock) +func (args *ManualExecArgs) ApproxDestStartBlock(ctx context.Context) error { + sourceBlockHdr, err := args.SourceChain.HeaderByNumber(ctx, args.SourceStartBlock) if err != nil { return err } @@ -1303,7 +1304,7 @@ func (args *ManualExecArgs) ApproxDestStartBlock() error { minBlockNum := args.DestDeployedAt closestBlockNum := uint64(math.Floor((float64(maxBlockNum) + float64(minBlockNum)) / 2)) var closestBlockHdr *types.Header - closestBlockHdr, err = args.DestChain.HeaderByNumber(context.Background(), big.NewInt(int64(closestBlockNum))) + closestBlockHdr, err = args.DestChain.HeaderByNumber(ctx, new(big.Int).SetUint64(closestBlockNum)) if err != nil { return err } @@ -1324,7 +1325,7 @@ func (args *ManualExecArgs) ApproxDestStartBlock() error { minBlockNum = blockNum + 1 } closestBlockNum = uint64(math.Floor((float64(maxBlockNum) + float64(minBlockNum)) / 2)) - closestBlockHdr, err = args.DestChain.HeaderByNumber(context.Background(), big.NewInt(int64(closestBlockNum))) + closestBlockHdr, err = args.DestChain.HeaderByNumber(ctx, new(big.Int).SetUint64(closestBlockNum)) if err != nil { return err } @@ -1335,7 +1336,7 @@ func (args *ManualExecArgs) ApproxDestStartBlock() error { if closestBlockNum <= 0 { return fmt.Errorf("approx destination blocknumber not found") } - closestBlockHdr, err = args.DestChain.HeaderByNumber(context.Background(), big.NewInt(int64(closestBlockNum))) + closestBlockHdr, err = args.DestChain.HeaderByNumber(ctx, new(big.Int).SetUint64(closestBlockNum)) if err != nil { return err } @@ -1371,7 +1372,7 @@ func (args *ManualExecArgs) FindSeqNrFromCCIPSendRequested() (uint64, error) { return seqNr, nil } -func (args *ManualExecArgs) ExecuteManually() (*types.Transaction, error) { +func (args *ManualExecArgs) ExecuteManually(ctx context.Context) (*types.Transaction, error) { if args.SourceChainID == 0 || args.DestChainID == 0 || args.DestUser == nil { @@ -1404,7 +1405,7 @@ func (args *ManualExecArgs) ExecuteManually() (*types.Transaction, error) { return nil, err } if args.DestStartBlock < 1 { - err = args.ApproxDestStartBlock() + err = args.ApproxDestStartBlock(ctx) if err != nil { return nil, err } @@ -1571,7 +1572,7 @@ func (c *CCIPContracts) ExecuteMessage( destStartBlock uint64, ) uint64 { t.Log("Executing request manually") - sendReqReceipt, err := c.Source.Chain.TransactionReceipt(context.Background(), txHash) + sendReqReceipt, err := c.Source.Chain.TransactionReceipt(tests.Context(t), txHash) require.NoError(t, err) args := ManualExecArgs{ SourceChainID: c.Source.ChainID, @@ -1588,11 +1589,12 @@ func (c *CCIPContracts) ExecuteMessage( OnRamp: c.Source.OnRamp.Address().String(), OffRamp: c.Dest.OffRamp.Address().String(), } - tx, err := args.ExecuteManually() + ctx := tests.Context(t) + tx, err := args.ExecuteManually(ctx) require.NoError(t, err) c.Dest.Chain.Commit() c.Source.Chain.Commit() - rec, err := c.Dest.Chain.TransactionReceipt(context.Background(), tx.Hash()) + rec, err := c.Dest.Chain.TransactionReceipt(ctx, tx.Hash()) require.NoError(t, err) require.Equal(t, uint64(1), rec.Status, "manual execution failed") t.Logf("Manual Execution completed for seqNum %d", args.SeqNr) diff --git a/core/services/ocr2/plugins/ccip/testhelpers/integration/chainlink.go b/core/services/ocr2/plugins/ccip/testhelpers/integration/chainlink.go index d0d502e8673..0b7f0de4d25 100644 --- a/core/services/ocr2/plugins/ccip/testhelpers/integration/chainlink.go +++ b/core/services/ocr2/plugins/ccip/testhelpers/integration/chainlink.go @@ -35,6 +35,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/config" "github.com/smartcontractkit/chainlink-common/pkg/loop" "github.com/smartcontractkit/chainlink-common/pkg/utils/mailbox" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccip" coretypes "github.com/smartcontractkit/chainlink-common/pkg/types/core/mocks" @@ -342,7 +343,7 @@ func (node *Node) AddJob(t *testing.T, spec *OCR2TaskJobSpec) { nil, ) require.NoError(t, err) - err = node.App.AddJobV2(context.Background(), &ccipJob) + err = node.App.AddJobV2(tests.Context(t), &ccipJob) require.NoError(t, err) } @@ -351,7 +352,7 @@ func (node *Node) AddBootstrapJob(t *testing.T, spec *OCR2TaskJobSpec) { require.NoError(t, err) ccipJob, err := ocrbootstrap.ValidatedBootstrapSpecToml(specString) require.NoError(t, err) - err = node.App.AddJobV2(context.Background(), &ccipJob) + err = node.App.AddJobV2(tests.Context(t), &ccipJob) require.NoError(t, err) } @@ -512,13 +513,13 @@ func setupNodeCCIP( lggr.Debug(fmt.Sprintf("Transmitter address %s chainID %s", transmitter, s.EVMChainID.String())) // Fund the commitTransmitter address with some ETH - n, err := destChain.NonceAt(context.Background(), owner.From, nil) + n, err := destChain.NonceAt(tests.Context(t), owner.From, nil) require.NoError(t, err) tx := types3.NewTransaction(n, transmitter, big.NewInt(1000000000000000000), 21000, big.NewInt(1000000000), nil) signedTx, err := owner.Signer(owner.From, tx) require.NoError(t, err) - err = destChain.SendTransaction(context.Background(), signedTx) + err = destChain.SendTransaction(tests.Context(t), signedTx) require.NoError(t, err) destChain.Commit() @@ -998,7 +999,7 @@ func (c *CCIPIntegrationTestHarness) SetupAndStartNodes(ctx context.Context, t * func (c *CCIPIntegrationTestHarness) SetUpNodesAndJobs(t *testing.T, pricePipeline string, priceGetterConfig string, usdcAttestationAPI string) CCIPJobSpecParams { // setup Jobs - ctx := context.Background() + ctx := tests.Context(t) // Starts nodes and configures them in the OCR contracts. bootstrapNode, _, configBlock := c.SetupAndStartNodes(ctx, t, int64(freeport.GetOne(t))) @@ -1011,7 +1012,7 @@ func (c *CCIPIntegrationTestHarness) SetUpNodesAndJobs(t *testing.T, pricePipeli // Replay for bootstrap. bc, err := bootstrapNode.App.GetRelayers().LegacyEVMChains().Get(strconv.FormatUint(c.Dest.ChainID, 10)) require.NoError(t, err) - require.NoError(t, bc.LogPoller().Replay(context.Background(), configBlock)) + require.NoError(t, bc.LogPoller().Replay(ctx, configBlock)) c.Dest.Chain.Commit() return jobParams diff --git a/core/services/ocr2/plugins/ccip/testhelpers/simulated_backend.go b/core/services/ocr2/plugins/ccip/testhelpers/simulated_backend.go index ea91362aaae..f48027545ad 100644 --- a/core/services/ocr2/plugins/ccip/testhelpers/simulated_backend.go +++ b/core/services/ocr2/plugins/ccip/testhelpers/simulated_backend.go @@ -1,20 +1,19 @@ package testhelpers import ( - "context" "math/big" "testing" "time" "github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/accounts/abi/bind/backends" - "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" ethtypes "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/eth/ethconfig" "github.com/stretchr/testify/require" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" "github.com/smartcontractkit/chainlink/v2/core/services/keystore" ) @@ -54,21 +53,13 @@ func (ks EthKeyStoreSim) Eth() keystore.Eth { return ks.ETHKS } -func (ks EthKeyStoreSim) SignTx(address common.Address, tx *ethtypes.Transaction, chainID *big.Int) (*ethtypes.Transaction, error) { - if chainID.String() == "1000" { - // A terrible hack, just for the multichain test. All simulation clients run on chainID 1337. - // We let the DestChainSelector actually use 1337 to make sure the offchainConfig digests are properly generated. - return ks.ETHKS.SignTx(context.Background(), address, tx, big.NewInt(1337)) - } - return ks.ETHKS.SignTx(context.Background(), address, tx, chainID) -} - var _ keystore.Eth = EthKeyStoreSim{}.ETHKS func ConfirmTxs(t *testing.T, txs []*ethtypes.Transaction, chain *backends.SimulatedBackend) { chain.Commit() + ctx := tests.Context(t) for _, tx := range txs { - rec, err := bind.WaitMined(context.Background(), chain, tx) + rec, err := bind.WaitMined(ctx, chain, tx) require.NoError(t, err) require.Equal(t, uint64(1), rec.Status) } diff --git a/core/services/ocr2/plugins/ccip/testhelpers/testhelpers_1_4_0/ccip_contracts_1_4_0.go b/core/services/ocr2/plugins/ccip/testhelpers/testhelpers_1_4_0/ccip_contracts_1_4_0.go index 64d3b5d26c1..ccdc93660c2 100644 --- a/core/services/ocr2/plugins/ccip/testhelpers/testhelpers_1_4_0/ccip_contracts_1_4_0.go +++ b/core/services/ocr2/plugins/ccip/testhelpers/testhelpers_1_4_0/ccip_contracts_1_4_0.go @@ -25,6 +25,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/hashutil" "github.com/smartcontractkit/chainlink-common/pkg/merklemulti" cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccip" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils" @@ -443,7 +444,7 @@ func (c *CCIPContracts) SetNopsOnRamp(t *testing.T, nopsAndWeights []evm_2_evm_o tx, err := c.Source.OnRamp.SetNops(c.Source.User, nopsAndWeights) require.NoError(t, err) c.Source.Chain.Commit() - _, err = bind.WaitMined(context.Background(), c.Source.Chain, tx) + _, err = bind.WaitMined(tests.Context(t), c.Source.Chain, tx) require.NoError(t, err) } @@ -583,7 +584,7 @@ func (c *CCIPContracts) SetupExecOCR2Config(t *testing.T, execOnchainConfig, exe func (c *CCIPContracts) SetupOnchainConfig(t *testing.T, commitOnchainConfig, commitOffchainConfig, execOnchainConfig, execOffchainConfig []byte) int64 { // Note We do NOT set the payees, payment is done in the OCR2Base implementation - blockBeforeConfig, err := c.Dest.Chain.BlockByNumber(context.Background(), nil) + blockBeforeConfig, err := c.Dest.Chain.BlockByNumber(tests.Context(t), nil) require.NoError(t, err) c.SetupCommitOCR2Config(t, commitOnchainConfig, commitOffchainConfig) @@ -1304,8 +1305,8 @@ type ManualExecArgs struct { // if the block located has a timestamp greater than the timestamp of mentioned source block // it just returns the first block found with lesser timestamp of the source block // providing a value of args.DestDeployedAt ensures better performance by reducing the range of block numbers to be traversed -func (args *ManualExecArgs) ApproxDestStartBlock() error { - sourceBlockHdr, err := args.SourceChain.HeaderByNumber(context.Background(), args.SourceStartBlock) +func (args *ManualExecArgs) ApproxDestStartBlock(ctx context.Context) error { + sourceBlockHdr, err := args.SourceChain.HeaderByNumber(ctx, args.SourceStartBlock) if err != nil { return err } @@ -1315,7 +1316,7 @@ func (args *ManualExecArgs) ApproxDestStartBlock() error { minBlockNum := args.DestDeployedAt closestBlockNum := uint64(math.Floor((float64(maxBlockNum) + float64(minBlockNum)) / 2)) var closestBlockHdr *types.Header - closestBlockHdr, err = args.DestChain.HeaderByNumber(context.Background(), big.NewInt(int64(closestBlockNum))) + closestBlockHdr, err = args.DestChain.HeaderByNumber(ctx, new(big.Int).SetUint64(closestBlockNum)) if err != nil { return err } @@ -1336,7 +1337,7 @@ func (args *ManualExecArgs) ApproxDestStartBlock() error { minBlockNum = blockNum + 1 } closestBlockNum = uint64(math.Floor((float64(maxBlockNum) + float64(minBlockNum)) / 2)) - closestBlockHdr, err = args.DestChain.HeaderByNumber(context.Background(), big.NewInt(int64(closestBlockNum))) + closestBlockHdr, err = args.DestChain.HeaderByNumber(ctx, new(big.Int).SetUint64(closestBlockNum)) if err != nil { return err } @@ -1347,7 +1348,7 @@ func (args *ManualExecArgs) ApproxDestStartBlock() error { if closestBlockNum <= 0 { return fmt.Errorf("approx destination blocknumber not found") } - closestBlockHdr, err = args.DestChain.HeaderByNumber(context.Background(), big.NewInt(int64(closestBlockNum))) + closestBlockHdr, err = args.DestChain.HeaderByNumber(ctx, new(big.Int).SetUint64(closestBlockNum)) if err != nil { return err } @@ -1383,7 +1384,7 @@ func (args *ManualExecArgs) FindSeqNrFromCCIPSendRequested() (uint64, error) { return seqNr, nil } -func (args *ManualExecArgs) ExecuteManually() (*types.Transaction, error) { +func (args *ManualExecArgs) ExecuteManually(ctx context.Context) (*types.Transaction, error) { if args.SourceChainID == 0 || args.DestChainID == 0 || args.DestUser == nil { @@ -1416,7 +1417,7 @@ func (args *ManualExecArgs) ExecuteManually() (*types.Transaction, error) { return nil, err } if args.DestStartBlock < 1 { - err = args.ApproxDestStartBlock() + err = args.ApproxDestStartBlock(ctx) if err != nil { return nil, err } @@ -1553,7 +1554,8 @@ func (c *CCIPContracts) ExecuteMessage( destStartBlock uint64, ) uint64 { t.Log("Executing request manually") - sendReqReceipt, err := c.Source.Chain.TransactionReceipt(context.Background(), txHash) + ctx := tests.Context(t) + sendReqReceipt, err := c.Source.Chain.TransactionReceipt(ctx, txHash) require.NoError(t, err) args := ManualExecArgs{ SourceChainID: c.Source.ChainID, @@ -1570,11 +1572,11 @@ func (c *CCIPContracts) ExecuteMessage( OnRamp: c.Source.OnRamp.Address().String(), OffRamp: c.Dest.OffRamp.Address().String(), } - tx, err := args.ExecuteManually() + tx, err := args.ExecuteManually(ctx) require.NoError(t, err) c.Dest.Chain.Commit() c.Source.Chain.Commit() - rec, err := c.Dest.Chain.TransactionReceipt(context.Background(), tx.Hash()) + rec, err := c.Dest.Chain.TransactionReceipt(tests.Context(t), tx.Hash()) require.NoError(t, err) require.Equal(t, uint64(1), rec.Status, "manual execution failed") t.Logf("Manual Execution completed for seqNum %d", args.SeqNr) diff --git a/core/services/ocr2/plugins/ccip/testhelpers/testhelpers_1_4_0/chainlink.go b/core/services/ocr2/plugins/ccip/testhelpers/testhelpers_1_4_0/chainlink.go index a69e284e548..b897d565bae 100644 --- a/core/services/ocr2/plugins/ccip/testhelpers/testhelpers_1_4_0/chainlink.go +++ b/core/services/ocr2/plugins/ccip/testhelpers/testhelpers_1_4_0/chainlink.go @@ -34,6 +34,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/config" "github.com/smartcontractkit/chainlink-common/pkg/loop" "github.com/smartcontractkit/chainlink-common/pkg/utils/mailbox" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccip" coretypes "github.com/smartcontractkit/chainlink-common/pkg/types/core/mocks" @@ -339,7 +340,7 @@ func (node *Node) AddJob(t *testing.T, spec *integrationtesthelpers.OCR2TaskJobS nil, ) require.NoError(t, err) - err = node.App.AddJobV2(context.Background(), &ccipJob) + err = node.App.AddJobV2(tests.Context(t), &ccipJob) require.NoError(t, err) } @@ -348,7 +349,7 @@ func (node *Node) AddBootstrapJob(t *testing.T, spec *integrationtesthelpers.OCR require.NoError(t, err) ccipJob, err := ocrbootstrap.ValidatedBootstrapSpecToml(specString) require.NoError(t, err) - err = node.App.AddJobV2(context.Background(), &ccipJob) + err = node.App.AddJobV2(tests.Context(t), &ccipJob) require.NoError(t, err) } @@ -508,13 +509,13 @@ func setupNodeCCIP( lggr.Debug(fmt.Sprintf("Transmitter address %s chainID %s", transmitter, s.EVMChainID.String())) // Fund the commitTransmitter address with some ETH - n, err := destChain.NonceAt(context.Background(), owner.From, nil) + n, err := destChain.NonceAt(tests.Context(t), owner.From, nil) require.NoError(t, err) tx := types3.NewTransaction(n, transmitter, big.NewInt(1000000000000000000), 21000, big.NewInt(1000000000), nil) signedTx, err := owner.Signer(owner.From, tx) require.NoError(t, err) - err = destChain.SendTransaction(context.Background(), signedTx) + err = destChain.SendTransaction(tests.Context(t), signedTx) require.NoError(t, err) destChain.Commit() @@ -944,7 +945,7 @@ func (c *CCIPIntegrationTestHarness) SetupAndStartNodes(ctx context.Context, t * func (c *CCIPIntegrationTestHarness) SetUpNodesAndJobs(t *testing.T, pricePipeline string, priceGetterConfig string, usdcAttestationAPI string) integrationtesthelpers.CCIPJobSpecParams { // setup Jobs - ctx := context.Background() + ctx := tests.Context(t) // Starts nodes and configures them in the OCR contracts. bootstrapNode, _, configBlock := c.SetupAndStartNodes(ctx, t, int64(freeport.GetOne(t))) @@ -957,7 +958,7 @@ func (c *CCIPIntegrationTestHarness) SetUpNodesAndJobs(t *testing.T, pricePipeli // Replay for bootstrap. bc, err := bootstrapNode.App.GetRelayers().LegacyEVMChains().Get(strconv.FormatUint(c.Dest.ChainID, 10)) require.NoError(t, err) - require.NoError(t, bc.LogPoller().Replay(context.Background(), configBlock)) + require.NoError(t, bc.LogPoller().Replay(tests.Context(t), configBlock)) c.Dest.Chain.Commit() return jobParams diff --git a/core/services/ocr2/plugins/ccip/tokendata/bgworker.go b/core/services/ocr2/plugins/ccip/tokendata/bgworker.go index 1a74ab2305b..bc5aba557e6 100644 --- a/core/services/ocr2/plugins/ccip/tokendata/bgworker.go +++ b/core/services/ocr2/plugins/ccip/tokendata/bgworker.go @@ -41,9 +41,8 @@ type BackgroundWorker struct { timeoutDur time.Duration services.StateMachine - wg *sync.WaitGroup - backgroundCtx context.Context //nolint:containedctx - backgroundCancel context.CancelFunc + wg sync.WaitGroup + stopChan services.StopChan } func NewBackgroundWorker( @@ -56,17 +55,13 @@ func NewBackgroundWorker( expirationDur = 24 * time.Hour } - ctx, cancel := context.WithCancel(context.Background()) return &BackgroundWorker{ tokenDataReaders: tokenDataReaders, numWorkers: numWorkers, jobsChan: make(chan cciptypes.EVM2EVMOnRampCCIPSendRequestedWithMeta, numWorkers*100), resultsCache: cache.New(expirationDur, expirationDur/2), timeoutDur: timeoutDur, - - wg: new(sync.WaitGroup), - backgroundCtx: ctx, - backgroundCancel: cancel, + stopChan: make(services.StopChan), } } @@ -82,7 +77,7 @@ func (w *BackgroundWorker) Start(context.Context) error { func (w *BackgroundWorker) Close() error { return w.StateMachine.StopOnce("Token BackgroundWorker", func() error { - w.backgroundCancel() + close(w.stopChan) w.wg.Wait() return nil }) @@ -90,12 +85,13 @@ func (w *BackgroundWorker) Close() error { func (w *BackgroundWorker) AddJobsFromMsgs(ctx context.Context, msgs []cciptypes.EVM2EVMOnRampCCIPSendRequestedWithMeta) { w.wg.Add(1) - go func() { + go func(ctx context.Context) { defer w.wg.Done() + ctx, cancel := w.stopChan.Ctx(ctx) + defer cancel() + for _, msg := range msgs { select { - case <-w.backgroundCtx.Done(): - return case <-ctx.Done(): return default: @@ -104,7 +100,7 @@ func (w *BackgroundWorker) AddJobsFromMsgs(ctx context.Context, msgs []cciptypes } } } - }() + }(ctx) } func (w *BackgroundWorker) GetReaders() map[cciptypes.Address]Reader { @@ -134,12 +130,15 @@ func (w *BackgroundWorker) GetMsgTokenData(ctx context.Context, msg cciptypes.EV func (w *BackgroundWorker) run() { go func() { defer w.wg.Done() + ctx, cancel := w.stopChan.NewCtx() + defer cancel() + for { select { - case <-w.backgroundCtx.Done(): + case <-ctx.Done(): return case msg := <-w.jobsChan: - w.workOnMsg(w.backgroundCtx, msg) + w.workOnMsg(ctx, msg) } } }() diff --git a/core/services/ocr2/plugins/ccip/tokendata/usdc/usdc_test.go b/core/services/ocr2/plugins/ccip/tokendata/usdc/usdc_test.go index 4210ecf75ea..786e88a6322 100644 --- a/core/services/ocr2/plugins/ccip/tokendata/usdc/usdc_test.go +++ b/core/services/ocr2/plugins/ccip/tokendata/usdc/usdc_test.go @@ -20,9 +20,11 @@ import ( "github.com/stretchr/testify/require" cciptypes "github.com/smartcontractkit/chainlink-common/pkg/types/ccip" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller/mocks" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils" + "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/ccipcalc" "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/plugins/ccip/internal/ccipdata" @@ -35,15 +37,17 @@ var ( ) func TestUSDCReader_callAttestationApi(t *testing.T) { + ctx := tests.Context(t) //nolint:staticcheck // SA4006 - false positive "unused" t.Skipf("Skipping test because it uses the real USDC attestation API") usdcMessageHash := "912f22a13e9ccb979b621500f6952b2afd6e75be7eadaed93fc2625fe11c52a2" attestationURI, err := url.ParseRequestURI("https://iris-api-sandbox.circle.com") require.NoError(t, err) lggr := logger.TestLogger(t) - usdcReader, _ := ccipdata.NewUSDCReader(lggr, "job_123", mockMsgTransmitter, nil, false) + usdcReader, err := ccipdata.NewUSDCReader(ctx, lggr, "job_123", mockMsgTransmitter, nil, false) + require.NoError(t, err) usdcService := NewUSDCTokenDataReader(lggr, usdcReader, attestationURI, 0, common.Address{}, APIIntervalRateLimitDisabled) - attestation, err := usdcService.callAttestationApi(context.Background(), [32]byte(common.FromHex(usdcMessageHash))) + attestation, err := usdcService.callAttestationApi(ctx, [32]byte(common.FromHex(usdcMessageHash))) require.NoError(t, err) require.Equal(t, attestationStatusPending, attestation.Status) @@ -52,6 +56,7 @@ func TestUSDCReader_callAttestationApi(t *testing.T) { func TestUSDCReader_callAttestationApiMock(t *testing.T) { t.Parallel() + ctx := tests.Context(t) response := attestationResponse{ Status: attestationStatusSuccess, Attestation: "720502893578a89a8a87982982ef781c18b193", @@ -64,9 +69,9 @@ func TestUSDCReader_callAttestationApiMock(t *testing.T) { lggr := logger.TestLogger(t) lp := mocks.NewLogPoller(t) - usdcReader, _ := ccipdata.NewUSDCReader(lggr, "job_123", mockMsgTransmitter, lp, false) + usdcReader, _ := ccipdata.NewUSDCReader(ctx, lggr, "job_123", mockMsgTransmitter, lp, false) usdcService := NewUSDCTokenDataReader(lggr, usdcReader, attestationURI, 0, common.Address{}, APIIntervalRateLimitDisabled) - attestation, err := usdcService.callAttestationApi(context.Background(), utils.RandomBytes32()) + attestation, err := usdcService.callAttestationApi(ctx, utils.RandomBytes32()) require.NoError(t, err) require.Equal(t, response.Status, attestation.Status) @@ -196,12 +201,13 @@ func TestUSDCReader_callAttestationApiMockError(t *testing.T) { lggr := logger.TestLogger(t) lp := mocks.NewLogPoller(t) - usdcReader, _ := ccipdata.NewUSDCReader(lggr, "job_123", mockMsgTransmitter, lp, false) + ctx := testutils.Context(t) + usdcReader, _ := ccipdata.NewUSDCReader(ctx, lggr, "job_123", mockMsgTransmitter, lp, false) usdcService := NewUSDCTokenDataReader(lggr, usdcReader, attestationURI, test.customTimeoutSeconds, common.Address{}, APIIntervalRateLimitDisabled) lp.On("RegisterFilter", mock.Anything, mock.Anything).Return(nil) - require.NoError(t, usdcReader.RegisterFilters()) + require.NoError(t, usdcReader.RegisterFilters(ctx)) - parentCtx, cancel := context.WithTimeout(context.Background(), time.Duration(test.parentTimeoutSeconds)*time.Second) + parentCtx, cancel := context.WithTimeout(ctx, time.Duration(test.parentTimeoutSeconds)*time.Second) defer cancel() _, err = usdcService.callAttestationApi(parentCtx, utils.RandomBytes32()) @@ -228,6 +234,7 @@ func getMockUSDCEndpoint(t *testing.T, response attestationResponse) *httptest.S func TestGetUSDCMessageBody(t *testing.T) { t.Parallel() + ctx := testutils.Context(t) expectedBody := []byte("0x0000000000000001000000020000000000048d71000000000000000000000000eb08f243e5d3fcff26a9e38ae5520a669f4019d000000000000000000000000023a04d5935ed8bc8e3eb78db3541f0abfb001c6e0000000000000000000000006cb3ed9b441eb674b58495c8b3324b59faff5243000000000000000000000000000000005425890298aed601595a70ab815c96711a31bc65000000000000000000000000ab4f961939bfe6a93567cc57c59eed7084ce2131000000000000000000000000000000000000000000000000000000000000271000000000000000000000000035e08285cfed1ef159236728f843286c55fc0861") usdcReader := ccipdatamocks.USDCReader{} usdcReader.On("GetUSDCMessagePriorToLogIndexInTx", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(expectedBody, nil) @@ -237,7 +244,7 @@ func TestGetUSDCMessageBody(t *testing.T) { usdcService := NewUSDCTokenDataReader(lggr, &usdcReader, nil, 0, usdcTokenAddr, APIIntervalRateLimitDisabled) // Make the first call and assert the underlying function is called - body, err := usdcService.getUSDCMessageBody(context.Background(), cciptypes.EVM2EVMOnRampCCIPSendRequestedWithMeta{ + body, err := usdcService.getUSDCMessageBody(ctx, cciptypes.EVM2EVMOnRampCCIPSendRequestedWithMeta{ EVM2EVMMessage: cciptypes.EVM2EVMMessage{ TokenAmounts: []cciptypes.TokenAmount{ { @@ -356,6 +363,7 @@ func TestUSDCReader_rateLimiting(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() + ctx := tests.Context(t) response := attestationResponse{ Status: attestationStatusSuccess, @@ -369,10 +377,9 @@ func TestUSDCReader_rateLimiting(t *testing.T) { lggr := logger.TestLogger(t) lp := mocks.NewLogPoller(t) - usdcReader, _ := ccipdata.NewUSDCReader(lggr, "job_123", mockMsgTransmitter, lp, false) + usdcReader, _ := ccipdata.NewUSDCReader(ctx, lggr, "job_123", mockMsgTransmitter, lp, false) usdcService := NewUSDCTokenDataReader(lggr, usdcReader, attestationURI, 0, utils.RandomAddress(), tc.rateConfig) - ctx := context.Background() if tc.timeout > 0 { var cf context.CancelFunc ctx, cf = context.WithTimeout(ctx, tc.timeout) diff --git a/core/services/ocr2/plugins/llo/integration_test.go b/core/services/ocr2/plugins/llo/integration_test.go index 7ab735bf122..206f8012e8b 100644 --- a/core/services/ocr2/plugins/llo/integration_test.go +++ b/core/services/ocr2/plugins/llo/integration_test.go @@ -323,6 +323,7 @@ func promoteStagingConfig(t *testing.T, donID uint32, steve *bind.TransactOpts, } func TestIntegration_LLO(t *testing.T) { + t.Parallel() testStartTimeStamp := time.Now() multiplier := decimal.New(1, 18) expirationWindow := time.Hour / time.Second @@ -808,7 +809,8 @@ func setupNodes(t *testing.T, nNodes int, backend *backends.SimulatedBackend, cl nodes = append(nodes, Node{ app, transmitter, kb, observedLogs, }) - offchainPublicKey, _ := hex.DecodeString(strings.TrimPrefix(kb.OnChainPublicKey(), "0x")) + offchainPublicKey, err := hex.DecodeString(strings.TrimPrefix(kb.OnChainPublicKey(), "0x")) + require.NoError(t, err) oracles = append(oracles, confighelper.OracleIdentityExtra{ OracleIdentity: confighelper.OracleIdentity{ OnchainPublicKey: offchainPublicKey, diff --git a/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/mercury/v02/request.go b/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/mercury/v02/request.go index c02b7c10de5..4adef132aab 100644 --- a/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/mercury/v02/request.go +++ b/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/mercury/v02/request.go @@ -70,8 +70,7 @@ func (c *client) DoRequest(ctx context.Context, streamsLookup *mercury.StreamsLo }) } - // TODO (AUTO 9090): Understand and fix the use of context.Background() here - reqTimeoutCtx, cancel := context.WithTimeout(context.Background(), mercury.RequestTimeout) + ctx, cancel := context.WithTimeout(ctx, mercury.RequestTimeout) defer cancel() state := encoding.NoPipelineError @@ -86,7 +85,7 @@ func (c *client) DoRequest(ctx context.Context, streamsLookup *mercury.StreamsLo // if no execution errors, then check if any feed returned an error code, if so use the last error code for i := 0; i < resultLen; i++ { select { - case <-reqTimeoutCtx.Done(): + case <-ctx.Done(): // Request Timed out, return timeout error c.lggr.Errorf("at block %s upkeep %s, streams lookup v0.2 timed out", streamsLookup.Time.String(), streamsLookup.UpkeepId.String()) return encoding.NoPipelineError, nil, encoding.ErrCodeStreamsTimeout, false, 0 * time.Second, nil diff --git a/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/mercury/v03/request.go b/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/mercury/v03/request.go index 3ade8cc7261..16892c88a59 100644 --- a/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/mercury/v03/request.go +++ b/core/services/ocr2/plugins/ocr2keeper/evmregistry/v21/mercury/v03/request.go @@ -74,11 +74,10 @@ func (c *client) DoRequest(ctx context.Context, streamsLookup *mercury.StreamsLo c.multiFeedsRequest(ctx, ch, streamsLookup) }) - // TODO (AUTO 9090): Understand and fix the use of context.Background() here - reqTimeoutCtx, cancel := context.WithTimeout(context.Background(), mercury.RequestTimeout) + ctx, cancel := context.WithTimeout(ctx, mercury.RequestTimeout) defer cancel() select { - case <-reqTimeoutCtx.Done(): + case <-ctx.Done(): // Request Timed out, return timeout error c.lggr.Errorf("at timestamp %s upkeep %s, streams lookup v0.3 timed out", streamsLookup.Time.String(), streamsLookup.UpkeepId.String()) return encoding.NoPipelineError, nil, encoding.ErrCodeStreamsTimeout, false, 0 * time.Second, nil diff --git a/core/services/pipeline/runner.go b/core/services/pipeline/runner.go index 185504fc0e4..1fc2fc46336 100644 --- a/core/services/pipeline/runner.go +++ b/core/services/pipeline/runner.go @@ -384,7 +384,7 @@ func (r *runner) run(ctx context.Context, pipeline *Pipeline, run *Run, vars Var // This is "just in case" for cleaning up any stray reports. // Normally the scheduler loop doesn't stop until all in progress runs report back - reportCtx, cancel := context.WithCancel(context.Background()) + reportCtx, cancel := context.WithCancel(context.WithoutCancel(ctx)) defer cancel() if pipelineTimeout := r.config.MaxRunDuration(); pipelineTimeout != 0 { diff --git a/core/services/relay/evm/ccip.go b/core/services/relay/evm/ccip.go index 3eefb7bec7b..a06f60c6fd4 100644 --- a/core/services/relay/evm/ccip.go +++ b/core/services/relay/evm/ccip.go @@ -131,8 +131,8 @@ type IncompleteDestCommitStoreReader struct { cs cciptypes.CommitStoreReader } -func NewIncompleteDestCommitStoreReader(lggr logger.Logger, versionFinder ccip.VersionFinder, address cciptypes.Address, ec client.Client, lp logpoller.LogPoller) (*IncompleteDestCommitStoreReader, error) { - cs, err := ccip.NewCommitStoreReader(lggr, versionFinder, address, ec, lp) +func NewIncompleteDestCommitStoreReader(ctx context.Context, lggr logger.Logger, versionFinder ccip.VersionFinder, address cciptypes.Address, ec client.Client, lp logpoller.LogPoller) (*IncompleteDestCommitStoreReader, error) { + cs, err := ccip.NewCommitStoreReader(ctx, lggr, versionFinder, address, ec, lp) if err != nil { return nil, err } diff --git a/core/services/relay/evm/commit_provider.go b/core/services/relay/evm/commit_provider.go index 71ac3846395..1a9260120f3 100644 --- a/core/services/relay/evm/commit_provider.go +++ b/core/services/relay/evm/commit_provider.go @@ -114,7 +114,7 @@ func (p *SrcCommitProvider) Close() error { if p.seenOnRampAddress == nil { return nil } - return ccip.CloseOnRampReader(p.lggr, versionFinder, *p.seenSourceChainSelector, *p.seenDestChainSelector, *p.seenOnRampAddress, p.lp, p.client) + return ccip.CloseOnRampReader(context.Background(), p.lggr, versionFinder, *p.seenSourceChainSelector, *p.seenDestChainSelector, *p.seenOnRampAddress, p.lp, p.client) }) var multiErr error @@ -165,25 +165,26 @@ func (p *DstCommitProvider) Name() string { } func (p *DstCommitProvider) Close() error { + ctx := context.Background() versionFinder := ccip.NewEvmVersionFinder() - unregisterFuncs := make([]func() error, 0, 2) - unregisterFuncs = append(unregisterFuncs, func() error { + unregisterFuncs := make([]func(ctx context.Context) error, 0, 2) + unregisterFuncs = append(unregisterFuncs, func(ctx context.Context) error { if p.seenCommitStoreAddress == nil { return nil } - return ccip.CloseCommitStoreReader(p.lggr, versionFinder, *p.seenCommitStoreAddress, p.client, p.lp) + return ccip.CloseCommitStoreReader(ctx, p.lggr, versionFinder, *p.seenCommitStoreAddress, p.client, p.lp) }) - unregisterFuncs = append(unregisterFuncs, func() error { + unregisterFuncs = append(unregisterFuncs, func(ctx context.Context) error { if p.seenOffRampAddress == nil { return nil } - return ccip.CloseOffRampReader(p.lggr, versionFinder, *p.seenOffRampAddress, p.client, p.lp, nil, big.NewInt(0)) + return ccip.CloseOffRampReader(ctx, p.lggr, versionFinder, *p.seenOffRampAddress, p.client, p.lp, nil, big.NewInt(0)) }) var multiErr error for _, fn := range unregisterFuncs { - if err := fn(); err != nil { + if err := fn(ctx); err != nil { multiErr = multierr.Append(multiErr, err) } } @@ -257,7 +258,7 @@ func (p *DstCommitProvider) NewCommitStoreReader(ctx context.Context, commitStor p.seenCommitStoreAddress = &commitStoreAddress versionFinder := ccip.NewEvmVersionFinder() - commitStoreReader, err = NewIncompleteDestCommitStoreReader(p.lggr, versionFinder, commitStoreAddress, p.client, p.lp) + commitStoreReader, err = NewIncompleteDestCommitStoreReader(ctx, p.lggr, versionFinder, commitStoreAddress, p.client, p.lp) return } @@ -267,7 +268,7 @@ func (p *SrcCommitProvider) NewOnRampReader(ctx context.Context, onRampAddress c p.seenDestChainSelector = &destChainSelector versionFinder := ccip.NewEvmVersionFinder() - onRampReader, err = ccip.NewOnRampReader(p.lggr, versionFinder, sourceChainSelector, destChainSelector, onRampAddress, p.lp, p.client) + onRampReader, err = ccip.NewOnRampReader(ctx, p.lggr, versionFinder, sourceChainSelector, destChainSelector, onRampAddress, p.lp, p.client) return } @@ -280,7 +281,7 @@ func (p *SrcCommitProvider) NewOffRampReader(ctx context.Context, offRampAddr cc } func (p *DstCommitProvider) NewOffRampReader(ctx context.Context, offRampAddr cciptypes.Address) (offRampReader cciptypes.OffRampReader, err error) { - offRampReader, err = ccip.NewOffRampReader(p.lggr, p.versionFinder, offRampAddr, p.client, p.lp, p.gasEstimator, &p.maxGasPrice, true) + offRampReader, err = ccip.NewOffRampReader(ctx, p.lggr, p.versionFinder, offRampAddr, p.client, p.lp, p.gasEstimator, &p.maxGasPrice, true) return } diff --git a/core/services/relay/evm/evm.go b/core/services/relay/evm/evm.go index 43b8408d2ee..7c380211ea0 100644 --- a/core/services/relay/evm/evm.go +++ b/core/services/relay/evm/evm.go @@ -572,7 +572,7 @@ func (r *Relayer) NewLLOProvider(ctx context.Context, rargs commontypes.RelayArg } configuratorAddress := common.HexToAddress(relayOpts.ContractID) - return NewLLOProvider(context.Background(), transmitter, r.lggr, r.retirementReportCache, r.chain, configuratorAddress, cdc, relayConfig, relayOpts) + return NewLLOProvider(ctx, transmitter, r.lggr, r.retirementReportCache, r.chain, configuratorAddress, cdc, relayConfig, relayOpts) } func (r *Relayer) NewFunctionsProvider(ctx context.Context, rargs commontypes.RelayArgs, pargs commontypes.PluginArgs) (commontypes.FunctionsProvider, error) { @@ -1022,6 +1022,7 @@ func (r *Relayer) NewCCIPExecProvider(ctx context.Context, rargs commontypes.Rel // bail early. if execPluginConfig.IsSourceProvider { return NewSrcExecProvider( + ctx, r.lggr, versionFinder, r.chain.Client(), diff --git a/core/services/relay/evm/exec_provider.go b/core/services/relay/evm/exec_provider.go index 98f85c23aa8..da190d20356 100644 --- a/core/services/relay/evm/exec_provider.go +++ b/core/services/relay/evm/exec_provider.go @@ -48,6 +48,7 @@ type SrcExecProvider struct { } func NewSrcExecProvider( + ctx context.Context, lggr logger.Logger, versionFinder ccip.VersionFinder, client client.Client, @@ -64,7 +65,7 @@ func NewSrcExecProvider( var usdcReader *ccip.USDCReaderImpl var err error if usdcAttestationAPI != "" { - usdcReader, err = ccip.NewUSDCReader(lggr, jobID, usdcSrcMsgTransmitterAddr, lp, true) + usdcReader, err = ccip.NewUSDCReader(ctx, lggr, jobID, usdcSrcMsgTransmitterAddr, lp, true) if err != nil { return nil, fmt.Errorf("new usdc reader: %w", err) } @@ -100,25 +101,26 @@ func (s *SrcExecProvider) Start(ctx context.Context) error { // Close is called when the job that created this provider is closed. func (s *SrcExecProvider) Close() error { + ctx := context.Background() versionFinder := ccip.NewEvmVersionFinder() - unregisterFuncs := make([]func() error, 0, 2) - unregisterFuncs = append(unregisterFuncs, func() error { + unregisterFuncs := make([]func(context.Context) error, 0, 2) + unregisterFuncs = append(unregisterFuncs, func(ctx context.Context) error { // avoid panic in the case NewOnRampReader wasn't called if s.seenOnRampAddress == nil { return nil } - return ccip.CloseOnRampReader(s.lggr, versionFinder, *s.seenSourceChainSelector, *s.seenDestChainSelector, *s.seenOnRampAddress, s.lp, s.client) + return ccip.CloseOnRampReader(ctx, s.lggr, versionFinder, *s.seenSourceChainSelector, *s.seenDestChainSelector, *s.seenOnRampAddress, s.lp, s.client) }) - unregisterFuncs = append(unregisterFuncs, func() error { + unregisterFuncs = append(unregisterFuncs, func(ctx context.Context) error { if s.usdcAttestationAPI == "" { return nil } - return ccip.CloseUSDCReader(s.lggr, s.lggr.Name(), s.usdcSrcMsgTransmitterAddr, s.lp) + return ccip.CloseUSDCReader(ctx, s.lggr, s.lggr.Name(), s.usdcSrcMsgTransmitterAddr, s.lp) }) var multiErr error for _, fn := range unregisterFuncs { - if err := fn(); err != nil { + if err := fn(ctx); err != nil { multiErr = multierr.Append(multiErr, err) } } @@ -176,7 +178,7 @@ func (s *SrcExecProvider) NewOnRampReader(ctx context.Context, onRampAddress cci s.seenOnRampAddress = &onRampAddress versionFinder := ccip.NewEvmVersionFinder() - onRampReader, err = ccip.NewOnRampReader(s.lggr, versionFinder, sourceChainSelector, destChainSelector, onRampAddress, s.lp, s.client) + onRampReader, err = ccip.NewOnRampReader(ctx, s.lggr, versionFinder, sourceChainSelector, destChainSelector, onRampAddress, s.lp, s.client) return } @@ -289,22 +291,23 @@ func (d *DstExecProvider) Start(ctx context.Context) error { // If NewOnRampReader and NewCommitStoreReader have not been called, their corresponding // Close methods will be expected to error. func (d *DstExecProvider) Close() error { + ctx := context.Background() versionFinder := ccip.NewEvmVersionFinder() - unregisterFuncs := make([]func() error, 0, 2) - unregisterFuncs = append(unregisterFuncs, func() error { + unregisterFuncs := make([]func(context.Context) error, 0, 2) + unregisterFuncs = append(unregisterFuncs, func(ctx context.Context) error { if d.seenCommitStoreAddr == nil { return nil } - return ccip.CloseCommitStoreReader(d.lggr, versionFinder, *d.seenCommitStoreAddr, d.client, d.lp) + return ccip.CloseCommitStoreReader(ctx, d.lggr, versionFinder, *d.seenCommitStoreAddr, d.client, d.lp) }) - unregisterFuncs = append(unregisterFuncs, func() error { - return ccip.CloseOffRampReader(d.lggr, versionFinder, d.offRampAddress, d.client, d.lp, nil, big.NewInt(0)) + unregisterFuncs = append(unregisterFuncs, func(ctx context.Context) error { + return ccip.CloseOffRampReader(ctx, d.lggr, versionFinder, d.offRampAddress, d.client, d.lp, nil, big.NewInt(0)) }) var multiErr error for _, fn := range unregisterFuncs { - if err := fn(); err != nil { + if err := fn(ctx); err != nil { multiErr = multierr.Append(multiErr, err) } } @@ -347,12 +350,12 @@ func (d *DstExecProvider) NewCommitStoreReader(ctx context.Context, addr cciptyp d.seenCommitStoreAddr = &addr versionFinder := ccip.NewEvmVersionFinder() - commitStoreReader, err = NewIncompleteDestCommitStoreReader(d.lggr, versionFinder, addr, d.client, d.lp) + commitStoreReader, err = NewIncompleteDestCommitStoreReader(ctx, d.lggr, versionFinder, addr, d.client, d.lp) return } func (d *DstExecProvider) NewOffRampReader(ctx context.Context, offRampAddress cciptypes.Address) (offRampReader cciptypes.OffRampReader, err error) { - offRampReader, err = ccip.NewOffRampReader(d.lggr, d.versionFinder, offRampAddress, d.client, d.lp, d.gasEstimator, &d.maxGasPrice, true) + offRampReader, err = ccip.NewOffRampReader(ctx, d.lggr, d.versionFinder, offRampAddress, d.client, d.lp, d.gasEstimator, &d.maxGasPrice, true) return } diff --git a/core/services/relay/evm/mercury/persistence_manager.go b/core/services/relay/evm/mercury/persistence_manager.go index dfe75e7c3ce..68137d04c14 100644 --- a/core/services/relay/evm/mercury/persistence_manager.go +++ b/core/services/relay/evm/mercury/persistence_manager.go @@ -87,7 +87,7 @@ func (pm *PersistenceManager) Load(ctx context.Context) ([]*Transmission, error) func (pm *PersistenceManager) runFlushDeletesLoop() { defer pm.wg.Done() - ctx, cancel := pm.stopCh.Ctx(context.Background()) + ctx, cancel := pm.stopCh.NewCtx() defer cancel() ticker := services.NewTicker(pm.flushDeletesFrequency) diff --git a/core/services/relay/evm/mercury/transmitter.go b/core/services/relay/evm/mercury/transmitter.go index b55cc8cf028..4e57a3d07cf 100644 --- a/core/services/relay/evm/mercury/transmitter.go +++ b/core/services/relay/evm/mercury/transmitter.go @@ -179,7 +179,7 @@ func (s *server) HealthReport() map[string]error { func (s *server) runDeleteQueueLoop(stopCh services.StopChan, wg *sync.WaitGroup) { defer wg.Done() - runloopCtx, cancel := stopCh.Ctx(context.Background()) + ctx, cancel := stopCh.NewCtx() defer cancel() // Exponential backoff for very rarely occurring errors (DB disconnect etc) @@ -194,7 +194,7 @@ func (s *server) runDeleteQueueLoop(stopCh services.StopChan, wg *sync.WaitGroup select { case req := <-s.deleteQueue: for { - if err := s.pm.Delete(runloopCtx, req); err != nil { + if err := s.pm.Delete(ctx, req); err != nil { s.lggr.Errorw("Failed to delete transmit request record", "err", err, "req.Payload", req.Payload) s.transmitQueueDeleteErrorCount.Inc() select { @@ -227,7 +227,7 @@ func (s *server) runQueueLoop(stopCh services.StopChan, wg *sync.WaitGroup, feed Factor: 2, Jitter: true, } - runloopCtx, cancel := stopCh.Ctx(context.Background()) + ctx, cancel := stopCh.NewCtx() defer cancel() for { t := s.q.BlockingPop() @@ -235,12 +235,13 @@ func (s *server) runQueueLoop(stopCh services.StopChan, wg *sync.WaitGroup, feed // queue was closed return } - ctx, cancel := context.WithTimeout(runloopCtx, utils.WithJitter(s.transmitTimeout)) - res, err := s.c.Transmit(ctx, t.Req) - cancel() - if runloopCtx.Err() != nil { - // runloop context is only canceled on transmitter close so we can - // exit the runloop here + res, err := func(ctx context.Context) (*pb.TransmitResponse, error) { + ctx, cancel := context.WithTimeout(ctx, utils.WithJitter(s.transmitTimeout)) + cancel() + return s.c.Transmit(ctx, t.Req) + }(ctx) + if ctx.Err() != nil { + // only canceled on transmitter close so we can exit return } else if err != nil { s.transmitConnectionErrorCount.Inc() diff --git a/core/services/relay/evm/mercury/wsrpc/client.go b/core/services/relay/evm/mercury/wsrpc/client.go index 37207510655..c87b555e6a5 100644 --- a/core/services/relay/evm/mercury/wsrpc/client.go +++ b/core/services/relay/evm/mercury/wsrpc/client.go @@ -189,7 +189,7 @@ func (w *client) resetTransport() { if !ok { panic("resetTransport should never be called unless client is in 'started' state") } - ctx, cancel := w.chStop.Ctx(context.Background()) + ctx, cancel := w.chStop.NewCtx() defer cancel() b := utils.NewRedialBackoff() for { diff --git a/core/sessions/ldapauth/sync.go b/core/sessions/ldapauth/sync.go index 5eeaf051526..e3ac8898101 100644 --- a/core/sessions/ldapauth/sync.go +++ b/core/sessions/ldapauth/sync.go @@ -9,8 +9,8 @@ import ( "github.com/go-ldap/ldap/v3" "github.com/lib/pq" + "github.com/smartcontractkit/chainlink-common/pkg/services" "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" - "github.com/smartcontractkit/chainlink-common/pkg/utils" "github.com/smartcontractkit/chainlink/v2/core/config" "github.com/smartcontractkit/chainlink/v2/core/logger" "github.com/smartcontractkit/chainlink/v2/core/sessions" @@ -22,110 +22,135 @@ type LDAPServerStateSyncer struct { config config.LDAP lggr logger.Logger nextSyncTime time.Time + done chan struct{} + stopCh services.StopChan } -// NewLDAPServerStateSync creates a reaper that cleans stale sessions from the store. -func NewLDAPServerStateSync( +// NewLDAPServerStateSyncer creates a reaper that cleans stale sessions from the store. +func NewLDAPServerStateSyncer( ds sqlutil.DataSource, config config.LDAP, lggr logger.Logger, -) *utils.SleeperTask { - namedLogger := lggr.Named("LDAPServerStateSync") - serverSync := LDAPServerStateSyncer{ - ds: ds, - ldapClient: newLDAPClient(config), - config: config, - lggr: namedLogger, - nextSyncTime: time.Time{}, +) *LDAPServerStateSyncer { + return &LDAPServerStateSyncer{ + ds: ds, + ldapClient: newLDAPClient(config), + config: config, + lggr: lggr.Named("LDAPServerStateSync"), + done: make(chan struct{}), + stopCh: make(services.StopChan), } +} + +func (l *LDAPServerStateSyncer) Name() string { + return l.lggr.Name() +} + +func (l *LDAPServerStateSyncer) Ready() error { return nil } + +func (l *LDAPServerStateSyncer) HealthReport() map[string]error { + return map[string]error{l.Name(): nil} +} + +func (l *LDAPServerStateSyncer) Start(ctx context.Context) error { // If enabled, start a background task that calls the Sync/Work function on an // interval without needing an auth event to trigger it // Use IsInstant to check 0 value to omit functionality. - if !config.UpstreamSyncInterval().IsInstant() { - lggr.Info("LDAP Config UpstreamSyncInterval is non-zero, sync functionality will be called on a timer, respecting the UpstreamSyncRateLimit value") - serverSync.StartWorkOnTimer() + if !l.config.UpstreamSyncInterval().IsInstant() { + l.lggr.Info("LDAP Config UpstreamSyncInterval is non-zero, sync functionality will be called on a timer, respecting the UpstreamSyncRateLimit value") + go l.run() } else { // Ensure upstream server state is synced on startup manually if interval check not set - serverSync.Work() + l.Work(ctx) } - - // Start background Sync call task reactive to auth related events - serverSyncSleeperTask := utils.NewSleeperTask(&serverSync) - return serverSyncSleeperTask + return nil } -func (ldSync *LDAPServerStateSyncer) Name() string { - return "LDAPServerStateSync" +func (l *LDAPServerStateSyncer) Close() error { + close(l.stopCh) + <-l.done + return nil } -func (ldSync *LDAPServerStateSyncer) StartWorkOnTimer() { - time.AfterFunc(ldSync.config.UpstreamSyncInterval().Duration(), ldSync.StartWorkOnTimer) - ldSync.Work() +func (l *LDAPServerStateSyncer) run() { + defer close(l.done) + ctx, cancel := l.stopCh.NewCtx() + defer cancel() + ticker := time.NewTicker(l.config.UpstreamSyncInterval().Duration()) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + l.Work(ctx) + } + } } -func (ldSync *LDAPServerStateSyncer) Work() { - ctx := context.Background() // TODO https://smartcontract-it.atlassian.net/browse/BCF-2887 +func (l *LDAPServerStateSyncer) Work(ctx context.Context) { // Purge expired ldap_sessions and ldap_user_api_tokens - recordCreationStaleThreshold := ldSync.config.SessionTimeout().Before(time.Now()) - err := ldSync.deleteStaleSessions(ctx, recordCreationStaleThreshold) + recordCreationStaleThreshold := l.config.SessionTimeout().Before(time.Now()) + err := l.deleteStaleSessions(ctx, recordCreationStaleThreshold) if err != nil { - ldSync.lggr.Error("unable to expire local LDAP sessions: ", err) + l.lggr.Error("unable to expire local LDAP sessions: ", err) } - recordCreationStaleThreshold = ldSync.config.UserAPITokenDuration().Before(time.Now()) - err = ldSync.deleteStaleAPITokens(ctx, recordCreationStaleThreshold) + recordCreationStaleThreshold = l.config.UserAPITokenDuration().Before(time.Now()) + err = l.deleteStaleAPITokens(ctx, recordCreationStaleThreshold) if err != nil { - ldSync.lggr.Error("unable to expire user API tokens: ", err) + l.lggr.Error("unable to expire user API tokens: ", err) } // Optional rate limiting check to limit the amount of upstream LDAP server queries performed - if !ldSync.config.UpstreamSyncRateLimit().IsInstant() { - if !time.Now().After(ldSync.nextSyncTime) { + if !l.config.UpstreamSyncRateLimit().IsInstant() { + if !time.Now().After(l.nextSyncTime) { return } // Enough time has elapsed to sync again, store the time for when next sync is allowed and begin sync - ldSync.nextSyncTime = time.Now().Add(ldSync.config.UpstreamSyncRateLimit().Duration()) + l.nextSyncTime = time.Now().Add(l.config.UpstreamSyncRateLimit().Duration()) } - ldSync.lggr.Info("Begin Upstream LDAP provider state sync after checking time against config UpstreamSyncInterval and UpstreamSyncRateLimit") + l.lggr.Info("Begin Upstream LDAP provider state sync after checking time against config UpstreamSyncInterval and UpstreamSyncRateLimit") // For each defined role/group, query for the list of group members to gather the full list of possible users users := []sessions.User{} - conn, err := ldSync.ldapClient.CreateEphemeralConnection() + conn, err := l.ldapClient.CreateEphemeralConnection() if err != nil { - ldSync.lggr.Error("Failed to Dial LDAP Server: ", err) + l.lggr.Error("Failed to Dial LDAP Server: ", err) return } // Root level root user auth with credentials provided from config - bindStr := ldSync.config.BaseUserAttr() + "=" + ldSync.config.ReadOnlyUserLogin() + "," + ldSync.config.BaseDN() - if err = conn.Bind(bindStr, ldSync.config.ReadOnlyUserPass()); err != nil { - ldSync.lggr.Error("Unable to login as initial root LDAP user: ", err) + bindStr := l.config.BaseUserAttr() + "=" + l.config.ReadOnlyUserLogin() + "," + l.config.BaseDN() + if err = conn.Bind(bindStr, l.config.ReadOnlyUserPass()); err != nil { + l.lggr.Error("Unable to login as initial root LDAP user: ", err) } defer conn.Close() // Query for list of uniqueMember IDs present in Admin group - adminUsers, err := ldSync.ldapGroupMembersListToUser(conn, ldSync.config.AdminUserGroupCN(), sessions.UserRoleAdmin) + adminUsers, err := l.ldapGroupMembersListToUser(conn, l.config.AdminUserGroupCN(), sessions.UserRoleAdmin) if err != nil { - ldSync.lggr.Error("Error in ldapGroupMembersListToUser: ", err) + l.lggr.Error("Error in ldapGroupMembersListToUser: ", err) return } // Query for list of uniqueMember IDs present in Edit group - editUsers, err := ldSync.ldapGroupMembersListToUser(conn, ldSync.config.EditUserGroupCN(), sessions.UserRoleEdit) + editUsers, err := l.ldapGroupMembersListToUser(conn, l.config.EditUserGroupCN(), sessions.UserRoleEdit) if err != nil { - ldSync.lggr.Error("Error in ldapGroupMembersListToUser: ", err) + l.lggr.Error("Error in ldapGroupMembersListToUser: ", err) return } // Query for list of uniqueMember IDs present in Edit group - runUsers, err := ldSync.ldapGroupMembersListToUser(conn, ldSync.config.RunUserGroupCN(), sessions.UserRoleRun) + runUsers, err := l.ldapGroupMembersListToUser(conn, l.config.RunUserGroupCN(), sessions.UserRoleRun) if err != nil { - ldSync.lggr.Error("Error in ldapGroupMembersListToUser: ", err) + l.lggr.Error("Error in ldapGroupMembersListToUser: ", err) return } // Query for list of uniqueMember IDs present in Edit group - readUsers, err := ldSync.ldapGroupMembersListToUser(conn, ldSync.config.ReadUserGroupCN(), sessions.UserRoleView) + readUsers, err := l.ldapGroupMembersListToUser(conn, l.config.ReadUserGroupCN(), sessions.UserRoleView) if err != nil { - ldSync.lggr.Error("Error in ldapGroupMembersListToUser: ", err) + l.lggr.Error("Error in ldapGroupMembersListToUser: ", err) return } @@ -147,9 +172,9 @@ func (ldSync *LDAPServerStateSyncer) Work() { // For each unique user in list of active sessions, check for 'Is Active' propery if defined in the config. Some LDAP providers // list group members that are no longer marked as active - usersActiveFlags, err := ldSync.validateUsersActive(dedupedEmails, conn) + usersActiveFlags, err := l.validateUsersActive(dedupedEmails, conn) if err != nil { - ldSync.lggr.Error("Error validating supplied user list: ", err) + l.lggr.Error("Error validating supplied user list: ", err) } // Remove users in the upstreamUserStateMap source of truth who are part of groups but marked as deactivated/no-active for i, active := range usersActiveFlags { @@ -160,7 +185,7 @@ func (ldSync *LDAPServerStateSyncer) Work() { // upstreamUserStateMap is now the most up to date source of truth // Now sync database sessions and roles with new data - err = sqlutil.TransactDataSource(ctx, ldSync.ds, nil, func(tx sqlutil.DataSource) error { + err = sqlutil.TransactDataSource(ctx, l.ds, nil, func(tx sqlutil.DataSource) error { // First, purge users present in the local ldap_sessions table but not in the upstream server type LDAPSession struct { UserEmail string @@ -248,36 +273,36 @@ func (ldSync *LDAPServerStateSyncer) Work() { } } - ldSync.lggr.Info("local ldap_sessions and ldap_user_api_tokens table successfully synced with upstream LDAP state") + l.lggr.Info("local ldap_sessions and ldap_user_api_tokens table successfully synced with upstream LDAP state") return nil }) if err != nil { - ldSync.lggr.Error("Error syncing local database state: ", err) + l.lggr.Error("Error syncing local database state: ", err) } - ldSync.lggr.Info("Upstream LDAP sync complete") + l.lggr.Info("Upstream LDAP sync complete") } // deleteStaleSessions deletes all ldap_sessions before the passed time. -func (ldSync *LDAPServerStateSyncer) deleteStaleSessions(ctx context.Context, before time.Time) error { - _, err := ldSync.ds.ExecContext(ctx, "DELETE FROM ldap_sessions WHERE created_at < $1", before) +func (l *LDAPServerStateSyncer) deleteStaleSessions(ctx context.Context, before time.Time) error { + _, err := l.ds.ExecContext(ctx, "DELETE FROM ldap_sessions WHERE created_at < $1", before) return err } // deleteStaleAPITokens deletes all ldap_user_api_tokens before the passed time. -func (ldSync *LDAPServerStateSyncer) deleteStaleAPITokens(ctx context.Context, before time.Time) error { - _, err := ldSync.ds.ExecContext(ctx, "DELETE FROM ldap_user_api_tokens WHERE created_at < $1", before) +func (l *LDAPServerStateSyncer) deleteStaleAPITokens(ctx context.Context, before time.Time) error { + _, err := l.ds.ExecContext(ctx, "DELETE FROM ldap_user_api_tokens WHERE created_at < $1", before) return err } // ldapGroupMembersListToUser queries the LDAP server given a conn for a list of uniqueMember who are part of the parameterized group -func (ldSync *LDAPServerStateSyncer) ldapGroupMembersListToUser(conn LDAPConn, groupNameCN string, roleToAssign sessions.UserRole) ([]sessions.User, error) { +func (l *LDAPServerStateSyncer) ldapGroupMembersListToUser(conn LDAPConn, groupNameCN string, roleToAssign sessions.UserRole) ([]sessions.User, error) { users, err := ldapGroupMembersListToUser( - conn, groupNameCN, roleToAssign, ldSync.config.GroupsDN(), - ldSync.config.BaseDN(), ldSync.config.QueryTimeout(), - ldSync.lggr, + conn, groupNameCN, roleToAssign, l.config.GroupsDN(), + l.config.BaseDN(), l.config.QueryTimeout(), + l.lggr, ) if err != nil { - ldSync.lggr.Errorf("Error listing members of group (%s): %v", groupNameCN, err) + l.lggr.Errorf("Error listing members of group (%s): %v", groupNameCN, err) return users, errors.New("error searching group members in LDAP directory") } return users, nil @@ -286,10 +311,10 @@ func (ldSync *LDAPServerStateSyncer) ldapGroupMembersListToUser(conn LDAPConn, g // validateUsersActive performs an additional LDAP server query for the supplied emails, checking the // returned user data for an 'active' property defined optionally in the config. // Returns same length bool 'valid' array, order preserved -func (ldSync *LDAPServerStateSyncer) validateUsersActive(emails []string, conn LDAPConn) ([]bool, error) { +func (l *LDAPServerStateSyncer) validateUsersActive(emails []string, conn LDAPConn) ([]bool, error) { validUsers := make([]bool, len(emails)) // If active attribute to check is not defined in config, skip - if ldSync.config.ActiveAttribute() == "" { + if l.config.ActiveAttribute() == "" { // pre fill with valids for i := range emails { validUsers[i] = true @@ -301,22 +326,22 @@ func (ldSync *LDAPServerStateSyncer) validateUsersActive(emails []string, conn L filterQuery := "(|" for _, email := range emails { escapedEmail := ldap.EscapeFilter(email) - filterQuery = fmt.Sprintf("%s(%s=%s)", filterQuery, ldSync.config.BaseUserAttr(), escapedEmail) + filterQuery = fmt.Sprintf("%s(%s=%s)", filterQuery, l.config.BaseUserAttr(), escapedEmail) } filterQuery = fmt.Sprintf("(&%s))", filterQuery) - searchBaseDN := fmt.Sprintf("%s,%s", ldSync.config.UsersDN(), ldSync.config.BaseDN()) + searchBaseDN := fmt.Sprintf("%s,%s", l.config.UsersDN(), l.config.BaseDN()) searchRequest := ldap.NewSearchRequest( searchBaseDN, ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, - 0, int(ldSync.config.QueryTimeout().Seconds()), false, + 0, int(l.config.QueryTimeout().Seconds()), false, filterQuery, - []string{ldSync.config.BaseUserAttr(), ldSync.config.ActiveAttribute()}, + []string{l.config.BaseUserAttr(), l.config.ActiveAttribute()}, nil, ) // Query LDAP server for the ActiveAttribute property of each specified user results, err := conn.Search(searchRequest) if err != nil { - ldSync.lggr.Errorf("Error searching user in LDAP query: %v", err) + l.lggr.Errorf("Error searching user in LDAP query: %v", err) return validUsers, errors.New("error searching users in LDAP directory") } // Ensure user response entries @@ -328,9 +353,9 @@ func (ldSync *LDAPServerStateSyncer) validateUsersActive(emails []string, conn L // keyed on email for final step to return flag bool list where order is preserved emailToActiveMap := make(map[string]bool) for _, result := range results.Entries { - isActiveAttribute := result.GetAttributeValue(ldSync.config.ActiveAttribute()) - uidAttribute := result.GetAttributeValue(ldSync.config.BaseUserAttr()) - emailToActiveMap[uidAttribute] = isActiveAttribute == ldSync.config.ActiveAttributeAllowedValue() + isActiveAttribute := result.GetAttributeValue(l.config.ActiveAttribute()) + uidAttribute := result.GetAttributeValue(l.config.BaseUserAttr()) + emailToActiveMap[uidAttribute] = isActiveAttribute == l.config.ActiveAttributeAllowedValue() } for i, email := range emails { active, ok := emailToActiveMap[email] diff --git a/core/sessions/localauth/reaper.go b/core/sessions/localauth/reaper.go index a3ba1693765..6f2bfe732c5 100644 --- a/core/sessions/localauth/reaper.go +++ b/core/sessions/localauth/reaper.go @@ -23,19 +23,16 @@ type SessionReaperConfig interface { // NewSessionReaper creates a reaper that cleans stale sessions from the store. func NewSessionReaper(ds sqlutil.DataSource, config SessionReaperConfig, lggr logger.Logger) *utils.SleeperTask { - return utils.NewSleeperTask(&sessionReaper{ + return utils.NewSleeperTaskCtx(&sessionReaper{ ds, config, lggr.Named("SessionReaper"), }) } -func (sr *sessionReaper) Name() string { - return "SessionReaper" -} +func (sr *sessionReaper) Name() string { return sr.lggr.Name() } -func (sr *sessionReaper) Work() { - ctx := context.Background() // TODO https://smartcontract-it.atlassian.net/browse/BCF-2887 +func (sr *sessionReaper) Work(ctx context.Context) { recordCreationStaleThreshold := sr.config.SessionReaperExpiration().Before( sr.config.SessionTimeout().Before(time.Now())) err := sr.deleteStaleSessions(ctx, recordCreationStaleThreshold) diff --git a/deployment/ccip/add_lane_test.go b/deployment/ccip/add_lane_test.go index a7618ecb712..d8443ad288b 100644 --- a/deployment/ccip/add_lane_test.go +++ b/deployment/ccip/add_lane_test.go @@ -17,6 +17,7 @@ import ( // TestAddLane covers the workflow of adding a lane between two chains and enabling it. // It also covers the case where the onRamp is disabled on the OffRamp contract initially and then enabled. func TestAddLane(t *testing.T) { + t.Parallel() // We add more chains to the chainlink nodes than the number of chains where CCIP is deployed. e := NewMemoryEnvironmentWithJobs(t, logger.TestLogger(t), 4, 4) // Here we have CR + nodes set up, but no CCIP contracts deployed. diff --git a/integration-tests/ccip-tests/actions/ccip_helpers.go b/integration-tests/ccip-tests/actions/ccip_helpers.go index a9a873f23cd..c24ae2ecd54 100644 --- a/integration-tests/ccip-tests/actions/ccip_helpers.go +++ b/integration-tests/ccip-tests/actions/ccip_helpers.go @@ -3072,7 +3072,7 @@ func (lane *CCIPLane) ExecuteManually(options ...ManualExecutionOption) error { GasLimit: big.NewInt(DefaultDestinationGasLimit), } timeNow := time.Now().UTC() - tx, err := args.ExecuteManually() + tx, err := args.ExecuteManually(lane.Context) if err != nil { return fmt.Errorf("could not execute manually: %w seqNum %d", err, seqNum) }