diff --git a/common/txmgr/mocks/tx_manager.go b/common/txmgr/mocks/tx_manager.go index a3e8c489314..974fd455903 100644 --- a/common/txmgr/mocks/tx_manager.go +++ b/common/txmgr/mocks/tx_manager.go @@ -273,9 +273,9 @@ func (_m *TxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) FindTx return r0, r1 } -// GetForwarderForEOA provides a mock function with given fields: eoa -func (_m *TxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) GetForwarderForEOA(eoa ADDR) (ADDR, error) { - ret := _m.Called(eoa) +// GetForwarderForEOA provides a mock function with given fields: ctx, eoa +func (_m *TxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) GetForwarderForEOA(ctx context.Context, eoa ADDR) (ADDR, error) { + ret := _m.Called(ctx, eoa) if len(ret) == 0 { panic("no return value specified for GetForwarderForEOA") @@ -283,17 +283,17 @@ func (_m *TxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) GetFor var r0 ADDR var r1 error - if rf, ok := ret.Get(0).(func(ADDR) (ADDR, error)); ok { - return rf(eoa) + if rf, ok := ret.Get(0).(func(context.Context, ADDR) (ADDR, error)); ok { + return rf(ctx, eoa) } - if rf, ok := ret.Get(0).(func(ADDR) ADDR); ok { - r0 = rf(eoa) + if rf, ok := ret.Get(0).(func(context.Context, ADDR) ADDR); ok { + r0 = rf(ctx, eoa) } else { r0 = ret.Get(0).(ADDR) } - if rf, ok := ret.Get(1).(func(ADDR) error); ok { - r1 = rf(eoa) + if rf, ok := ret.Get(1).(func(context.Context, ADDR) error); ok { + r1 = rf(ctx, eoa) } else { r1 = ret.Error(1) } @@ -301,9 +301,9 @@ func (_m *TxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) GetFor return r0, r1 } -// GetForwarderForEOAOCR2Feeds provides a mock function with given fields: eoa, ocr2AggregatorID -func (_m *TxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) GetForwarderForEOAOCR2Feeds(eoa ADDR, ocr2AggregatorID ADDR) (ADDR, error) { - ret := _m.Called(eoa, ocr2AggregatorID) +// GetForwarderForEOAOCR2Feeds provides a mock function with given fields: ctx, eoa, ocr2AggregatorID +func (_m *TxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) GetForwarderForEOAOCR2Feeds(ctx context.Context, eoa ADDR, ocr2AggregatorID ADDR) (ADDR, error) { + ret := _m.Called(ctx, eoa, ocr2AggregatorID) if len(ret) == 0 { panic("no return value specified for GetForwarderForEOAOCR2Feeds") @@ -311,17 +311,17 @@ func (_m *TxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) GetFor var r0 ADDR var r1 error - if rf, ok := ret.Get(0).(func(ADDR, ADDR) (ADDR, error)); ok { - return rf(eoa, ocr2AggregatorID) + if rf, ok := ret.Get(0).(func(context.Context, ADDR, ADDR) (ADDR, error)); ok { + return rf(ctx, eoa, ocr2AggregatorID) } - if rf, ok := ret.Get(0).(func(ADDR, ADDR) ADDR); ok { - r0 = rf(eoa, ocr2AggregatorID) + if rf, ok := ret.Get(0).(func(context.Context, ADDR, ADDR) ADDR); ok { + r0 = rf(ctx, eoa, ocr2AggregatorID) } else { r0 = ret.Get(0).(ADDR) } - if rf, ok := ret.Get(1).(func(ADDR, ADDR) error); ok { - r1 = rf(eoa, ocr2AggregatorID) + if rf, ok := ret.Get(1).(func(context.Context, ADDR, ADDR) error); ok { + r1 = rf(ctx, eoa, ocr2AggregatorID) } else { r1 = ret.Error(1) } diff --git a/common/txmgr/txmgr.go b/common/txmgr/txmgr.go index 1c8b59a55cc..44b518fdaab 100644 --- a/common/txmgr/txmgr.go +++ b/common/txmgr/txmgr.go @@ -46,8 +46,8 @@ type TxManager[ services.Service Trigger(addr ADDR) CreateTransaction(ctx context.Context, txRequest txmgrtypes.TxRequest[ADDR, TX_HASH]) (etx txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], err error) - GetForwarderForEOA(eoa ADDR) (forwarder ADDR, err error) - GetForwarderForEOAOCR2Feeds(eoa, ocr2AggregatorID ADDR) (forwarder ADDR, err error) + GetForwarderForEOA(ctx context.Context, eoa ADDR) (forwarder ADDR, err error) + GetForwarderForEOAOCR2Feeds(ctx context.Context, eoa, ocr2AggregatorID ADDR) (forwarder ADDR, err error) RegisterResumeCallback(fn ResumeCallback) SendNativeToken(ctx context.Context, chainID CHAIN_ID, from, to ADDR, value big.Int, gasLimit uint64) (etx txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], err error) Reset(addr ADDR, abandon bool) error @@ -546,20 +546,20 @@ func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) CreateTran } // Calls forwarderMgr to get a proper forwarder for a given EOA. -func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) GetForwarderForEOA(eoa ADDR) (forwarder ADDR, err error) { +func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) GetForwarderForEOA(ctx context.Context, eoa ADDR) (forwarder ADDR, err error) { if !b.txConfig.ForwardersEnabled() { return forwarder, fmt.Errorf("forwarding is not enabled, to enable set Transactions.ForwardersEnabled =true") } - forwarder, err = b.fwdMgr.ForwarderFor(eoa) + forwarder, err = b.fwdMgr.ForwarderFor(ctx, eoa) return } // GetForwarderForEOAOCR2Feeds calls forwarderMgr to get a proper forwarder for a given EOA and checks if its set as a transmitter on the OCR2Aggregator contract. -func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) GetForwarderForEOAOCR2Feeds(eoa, ocr2Aggregator ADDR) (forwarder ADDR, err error) { +func (b *Txm[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, R, SEQ, FEE]) GetForwarderForEOAOCR2Feeds(ctx context.Context, eoa, ocr2Aggregator ADDR) (forwarder ADDR, err error) { if !b.txConfig.ForwardersEnabled() { return forwarder, fmt.Errorf("forwarding is not enabled, to enable set Transactions.ForwardersEnabled =true") } - forwarder, err = b.fwdMgr.ForwarderForOCR2Feeds(eoa, ocr2Aggregator) + forwarder, err = b.fwdMgr.ForwarderForOCR2Feeds(ctx, eoa, ocr2Aggregator) return } @@ -656,10 +656,10 @@ func (n *NullTxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) Tri func (n *NullTxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) CreateTransaction(ctx context.Context, txRequest txmgrtypes.TxRequest[ADDR, TX_HASH]) (etx txmgrtypes.Tx[CHAIN_ID, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE], err error) { return etx, errors.New(n.ErrMsg) } -func (n *NullTxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) GetForwarderForEOA(addr ADDR) (fwdr ADDR, err error) { +func (n *NullTxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) GetForwarderForEOA(ctx context.Context, addr ADDR) (fwdr ADDR, err error) { return fwdr, err } -func (n *NullTxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) GetForwarderForEOAOCR2Feeds(_, _ ADDR) (fwdr ADDR, err error) { +func (n *NullTxManager[CHAIN_ID, HEAD, ADDR, TX_HASH, BLOCK_HASH, SEQ, FEE]) GetForwarderForEOAOCR2Feeds(ctx context.Context, _, _ ADDR) (fwdr ADDR, err error) { return fwdr, err } diff --git a/common/txmgr/types/forwarder_manager.go b/common/txmgr/types/forwarder_manager.go index 3e51ffb1524..6acb491a1fb 100644 --- a/common/txmgr/types/forwarder_manager.go +++ b/common/txmgr/types/forwarder_manager.go @@ -1,15 +1,18 @@ package types import ( + "context" + "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink/v2/common/types" ) //go:generate mockery --quiet --name ForwarderManager --output ./mocks/ --case=underscore type ForwarderManager[ADDR types.Hashable] interface { services.Service - ForwarderFor(addr ADDR) (forwarder ADDR, err error) - ForwarderForOCR2Feeds(eoa, ocr2Aggregator ADDR) (forwarder ADDR, err error) + ForwarderFor(ctx context.Context, addr ADDR) (forwarder ADDR, err error) + ForwarderForOCR2Feeds(ctx context.Context, eoa, ocr2Aggregator ADDR) (forwarder ADDR, err error) // Converts payload to be forwarder-friendly ConvertPayload(dest ADDR, origPayload []byte) ([]byte, error) } diff --git a/common/txmgr/types/mocks/forwarder_manager.go b/common/txmgr/types/mocks/forwarder_manager.go index 1021e776e9d..b2cf9bc9d35 100644 --- a/common/txmgr/types/mocks/forwarder_manager.go +++ b/common/txmgr/types/mocks/forwarder_manager.go @@ -63,9 +63,9 @@ func (_m *ForwarderManager[ADDR]) ConvertPayload(dest ADDR, origPayload []byte) return r0, r1 } -// ForwarderFor provides a mock function with given fields: addr -func (_m *ForwarderManager[ADDR]) ForwarderFor(addr ADDR) (ADDR, error) { - ret := _m.Called(addr) +// ForwarderFor provides a mock function with given fields: ctx, addr +func (_m *ForwarderManager[ADDR]) ForwarderFor(ctx context.Context, addr ADDR) (ADDR, error) { + ret := _m.Called(ctx, addr) if len(ret) == 0 { panic("no return value specified for ForwarderFor") @@ -73,17 +73,17 @@ func (_m *ForwarderManager[ADDR]) ForwarderFor(addr ADDR) (ADDR, error) { var r0 ADDR var r1 error - if rf, ok := ret.Get(0).(func(ADDR) (ADDR, error)); ok { - return rf(addr) + if rf, ok := ret.Get(0).(func(context.Context, ADDR) (ADDR, error)); ok { + return rf(ctx, addr) } - if rf, ok := ret.Get(0).(func(ADDR) ADDR); ok { - r0 = rf(addr) + if rf, ok := ret.Get(0).(func(context.Context, ADDR) ADDR); ok { + r0 = rf(ctx, addr) } else { r0 = ret.Get(0).(ADDR) } - if rf, ok := ret.Get(1).(func(ADDR) error); ok { - r1 = rf(addr) + if rf, ok := ret.Get(1).(func(context.Context, ADDR) error); ok { + r1 = rf(ctx, addr) } else { r1 = ret.Error(1) } @@ -91,9 +91,9 @@ func (_m *ForwarderManager[ADDR]) ForwarderFor(addr ADDR) (ADDR, error) { return r0, r1 } -// ForwarderForOCR2Feeds provides a mock function with given fields: eoa, ocr2Aggregator -func (_m *ForwarderManager[ADDR]) ForwarderForOCR2Feeds(eoa ADDR, ocr2Aggregator ADDR) (ADDR, error) { - ret := _m.Called(eoa, ocr2Aggregator) +// ForwarderForOCR2Feeds provides a mock function with given fields: ctx, eoa, ocr2Aggregator +func (_m *ForwarderManager[ADDR]) ForwarderForOCR2Feeds(ctx context.Context, eoa ADDR, ocr2Aggregator ADDR) (ADDR, error) { + ret := _m.Called(ctx, eoa, ocr2Aggregator) if len(ret) == 0 { panic("no return value specified for ForwarderForOCR2Feeds") @@ -101,17 +101,17 @@ func (_m *ForwarderManager[ADDR]) ForwarderForOCR2Feeds(eoa ADDR, ocr2Aggregator var r0 ADDR var r1 error - if rf, ok := ret.Get(0).(func(ADDR, ADDR) (ADDR, error)); ok { - return rf(eoa, ocr2Aggregator) + if rf, ok := ret.Get(0).(func(context.Context, ADDR, ADDR) (ADDR, error)); ok { + return rf(ctx, eoa, ocr2Aggregator) } - if rf, ok := ret.Get(0).(func(ADDR, ADDR) ADDR); ok { - r0 = rf(eoa, ocr2Aggregator) + if rf, ok := ret.Get(0).(func(context.Context, ADDR, ADDR) ADDR); ok { + r0 = rf(ctx, eoa, ocr2Aggregator) } else { r0 = ret.Get(0).(ADDR) } - if rf, ok := ret.Get(1).(func(ADDR, ADDR) error); ok { - r1 = rf(eoa, ocr2Aggregator) + if rf, ok := ret.Get(1).(func(context.Context, ADDR, ADDR) error); ok { + r1 = rf(ctx, eoa, ocr2Aggregator) } else { r1 = ret.Error(1) } diff --git a/core/chains/evm/forwarders/forwarder_manager.go b/core/chains/evm/forwarders/forwarder_manager.go index 15e3534e8cb..b8035d0a62b 100644 --- a/core/chains/evm/forwarders/forwarder_manager.go +++ b/core/chains/evm/forwarders/forwarder_manager.go @@ -2,6 +2,7 @@ package forwarders import ( "context" + "errors" "slices" "sync" "time" @@ -111,9 +112,9 @@ func FilterName(addr common.Address) string { return evmlogpoller.FilterName("ForwarderManager AuthorizedSendersChanged", addr.String()) } -func (f *FwdMgr) ForwarderFor(addr common.Address) (forwarder common.Address, err error) { +func (f *FwdMgr) ForwarderFor(ctx context.Context, addr common.Address) (forwarder common.Address, err error) { // Gets forwarders for current chain. - fwdrs, err := f.ORM.FindForwardersByChain(f.ctx, big.Big(*f.evmClient.ConfiguredChainID())) + fwdrs, err := f.ORM.FindForwardersByChain(ctx, big.Big(*f.evmClient.ConfiguredChainID())) if err != nil { return common.Address{}, err } @@ -130,11 +131,14 @@ func (f *FwdMgr) ForwarderFor(addr common.Address) (forwarder common.Address, er } } } - return common.Address{}, pkgerrors.Errorf("Cannot find forwarder for given EOA") + return common.Address{}, ErrForwarderForEOANotFound } -func (f *FwdMgr) ForwarderForOCR2Feeds(eoa, ocr2Aggregator common.Address) (forwarder common.Address, err error) { - fwdrs, err := f.ORM.FindForwardersByChain(f.ctx, big.Big(*f.evmClient.ConfiguredChainID())) +// ErrForwarderForEOANotFound defines the error triggered when no valid forwarders were found for EOA +var ErrForwarderForEOANotFound = errors.New("cannot find forwarder for given EOA") + +func (f *FwdMgr) ForwarderForOCR2Feeds(ctx context.Context, eoa, ocr2Aggregator common.Address) (forwarder common.Address, err error) { + fwdrs, err := f.ORM.FindForwardersByChain(ctx, big.Big(*f.evmClient.ConfiguredChainID())) if err != nil { return common.Address{}, err } @@ -144,7 +148,7 @@ func (f *FwdMgr) ForwarderForOCR2Feeds(eoa, ocr2Aggregator common.Address) (forw return common.Address{}, err } - transmitters, err := offchainAggregator.GetTransmitters(&bind.CallOpts{Context: f.ctx}) + transmitters, err := offchainAggregator.GetTransmitters(&bind.CallOpts{Context: ctx}) if err != nil { return common.Address{}, pkgerrors.Errorf("failed to get ocr2 aggregator transmitters: %s", err.Error()) } @@ -166,7 +170,7 @@ func (f *FwdMgr) ForwarderForOCR2Feeds(eoa, ocr2Aggregator common.Address) (forw } } } - return common.Address{}, pkgerrors.Errorf("Cannot find forwarder for given EOA") + return common.Address{}, ErrForwarderForEOANotFound } func (f *FwdMgr) ConvertPayload(dest common.Address, origPayload []byte) ([]byte, error) { diff --git a/core/chains/evm/forwarders/forwarder_manager_test.go b/core/chains/evm/forwarders/forwarder_manager_test.go index 993efacac4a..be8513f5925 100644 --- a/core/chains/evm/forwarders/forwarder_manager_test.go +++ b/core/chains/evm/forwarders/forwarder_manager_test.go @@ -18,6 +18,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/sqlutil" "github.com/smartcontractkit/chainlink-common/pkg/utils" + "github.com/smartcontractkit/chainlink/v2/core/services/ocr2/testhelpers" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/client" @@ -86,7 +87,7 @@ func TestFwdMgr_MaybeForwardTransaction(t *testing.T) { require.Equal(t, lst[0].Address, forwarderAddr) require.NoError(t, fwdMgr.Start(testutils.Context(t))) - addr, err := fwdMgr.ForwarderFor(owner.From) + addr, err := fwdMgr.ForwarderFor(ctx, owner.From) require.NoError(t, err) require.Equal(t, addr.String(), forwarderAddr.String()) err = fwdMgr.Close() @@ -148,8 +149,8 @@ func TestFwdMgr_AccountUnauthorizedToForward_SkipsForwarding(t *testing.T) { err = fwdMgr.Start(testutils.Context(t)) require.NoError(t, err) - addr, err := fwdMgr.ForwarderFor(owner.From) - require.ErrorContains(t, err, "Cannot find forwarder for given EOA") + addr, err := fwdMgr.ForwarderFor(ctx, owner.From) + require.ErrorIs(t, err, forwarders.ErrForwarderForEOANotFound) require.True(t, utils.IsZero(addr)) err = fwdMgr.Close() require.NoError(t, err) @@ -214,8 +215,8 @@ func TestFwdMgr_InvalidForwarderForOCR2FeedsStates(t *testing.T) { fwdMgr = forwarders.NewFwdMgr(db, evmClient, lp, lggr, evmcfg.EVM()) require.NoError(t, fwdMgr.Start(testutils.Context(t))) // cannot find forwarder because it isn't authorized nor added as a transmitter - addr, err := fwdMgr.ForwarderForOCR2Feeds(owner.From, ocr2Address) - require.ErrorContains(t, err, "Cannot find forwarder for given EOA") + addr, err := fwdMgr.ForwarderForOCR2Feeds(ctx, owner.From, ocr2Address) + require.ErrorIs(t, err, forwarders.ErrForwarderForEOANotFound) require.True(t, utils.IsZero(addr)) _, err = forwarder.SetAuthorizedSenders(owner, []common.Address{owner.From}) @@ -227,8 +228,8 @@ func TestFwdMgr_InvalidForwarderForOCR2FeedsStates(t *testing.T) { require.Equal(t, owner.From, authorizedSenders[0]) // cannot find forwarder because it isn't added as a transmitter - addr, err = fwdMgr.ForwarderForOCR2Feeds(owner.From, ocr2Address) - require.ErrorContains(t, err, "Cannot find forwarder for given EOA") + addr, err = fwdMgr.ForwarderForOCR2Feeds(ctx, owner.From, ocr2Address) + require.ErrorIs(t, err, forwarders.ErrForwarderForEOANotFound) require.True(t, utils.IsZero(addr)) onchainConfig, err := testhelpers.GenerateDefaultOCR2OnchainConfig(big.NewInt(0), big.NewInt(10)) @@ -251,7 +252,7 @@ func TestFwdMgr_InvalidForwarderForOCR2FeedsStates(t *testing.T) { // create new fwd to have an empty cache that has to fetch authorized forwarders from log poller fwdMgr = forwarders.NewFwdMgr(db, evmClient, lp, lggr, evmcfg.EVM()) require.NoError(t, fwdMgr.Start(testutils.Context(t))) - addr, err = fwdMgr.ForwarderForOCR2Feeds(owner.From, ocr2Address) + addr, err = fwdMgr.ForwarderForOCR2Feeds(ctx, owner.From, ocr2Address) require.NoError(t, err, "forwarder should be valid and found because it is both authorized and set as a transmitter") require.Equal(t, forwarderAddr, addr) require.NoError(t, fwdMgr.Close()) diff --git a/core/services/keeper/delegate.go b/core/services/keeper/delegate.go index 71a0c5c43a9..c9d189b30c5 100644 --- a/core/services/keeper/delegate.go +++ b/core/services/keeper/delegate.go @@ -93,7 +93,7 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, spec job.Job) (services // In the case of forwarding, the keeper address is the forwarder contract deployed onchain between EOA and Registry. effectiveKeeperAddress := spec.KeeperSpec.FromAddress.Address() if spec.ForwardingAllowed { - fwdrAddress, fwderr := chain.TxManager().GetForwarderForEOA(spec.KeeperSpec.FromAddress.Address()) + fwdrAddress, fwderr := chain.TxManager().GetForwarderForEOA(ctx, spec.KeeperSpec.FromAddress.Address()) if fwderr == nil { effectiveKeeperAddress = fwdrAddress } else { diff --git a/core/services/keeper/integration_test.go b/core/services/keeper/integration_test.go index 9e4cf5f9041..cbbe89b3f21 100644 --- a/core/services/keeper/integration_test.go +++ b/core/services/keeper/integration_test.go @@ -417,7 +417,7 @@ func TestKeeperForwarderEthIntegration(t *testing.T) { _, err = forwarderORM.CreateForwarder(ctx, fwdrAddress, chainID) require.NoError(t, err) - addr, err := app.GetRelayers().LegacyEVMChains().Slice()[0].TxManager().GetForwarderForEOA(nodeAddress) + addr, err := app.GetRelayers().LegacyEVMChains().Slice()[0].TxManager().GetForwarderForEOA(ctx, nodeAddress) require.NoError(t, err) require.Equal(t, addr, fwdrAddress) diff --git a/core/services/ocr/delegate.go b/core/services/ocr/delegate.go index e748823ad71..a47e7ec9e7d 100644 --- a/core/services/ocr/delegate.go +++ b/core/services/ocr/delegate.go @@ -216,7 +216,7 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) (services [] // In the case of forwarding, the transmitter address is the forwarder contract deployed onchain between EOA and OCR contract. effectiveTransmitterAddress := concreteSpec.TransmitterAddress.Address() if jb.ForwardingAllowed { - fwdrAddress, fwderr := chain.TxManager().GetForwarderForEOA(effectiveTransmitterAddress) + fwdrAddress, fwderr := chain.TxManager().GetForwarderForEOA(ctx, effectiveTransmitterAddress) if fwderr == nil { effectiveTransmitterAddress = fwdrAddress } else { diff --git a/core/services/ocr2/delegate.go b/core/services/ocr2/delegate.go index 5d149d140f1..f1f0eba61b0 100644 --- a/core/services/ocr2/delegate.go +++ b/core/services/ocr2/delegate.go @@ -381,7 +381,7 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) ([]job.Servi if err2 != nil { return nil, fmt.Errorf("ServicesForSpec: could not get EVM chain %s: %w", rid.ChainID, err2) } - effectiveTransmitterID, err2 = GetEVMEffectiveTransmitterID(&jb, chain, lggr) + effectiveTransmitterID, err2 = GetEVMEffectiveTransmitterID(ctx, &jb, chain, lggr) if err2 != nil { return nil, fmt.Errorf("ServicesForSpec failed to get evm transmitterID: %w", err2) } @@ -469,7 +469,7 @@ func (d *Delegate) ServicesForSpec(ctx context.Context, jb job.Job) ([]job.Servi } } -func GetEVMEffectiveTransmitterID(jb *job.Job, chain legacyevm.Chain, lggr logger.SugaredLogger) (string, error) { +func GetEVMEffectiveTransmitterID(ctx context.Context, jb *job.Job, chain legacyevm.Chain, lggr logger.SugaredLogger) (string, error) { spec := jb.OCR2OracleSpec if spec.PluginType == types.Mercury || spec.PluginType == types.LLO { return spec.TransmitterID.String, nil @@ -500,9 +500,9 @@ func GetEVMEffectiveTransmitterID(jb *job.Job, chain legacyevm.Chain, lggr logge var effectiveTransmitterID common.Address // Median forwarders need special handling because of OCR2Aggregator transmitters whitelist. if spec.PluginType == types.Median { - effectiveTransmitterID, err = chain.TxManager().GetForwarderForEOAOCR2Feeds(common.HexToAddress(spec.TransmitterID.String), common.HexToAddress(spec.ContractID)) + effectiveTransmitterID, err = chain.TxManager().GetForwarderForEOAOCR2Feeds(ctx, common.HexToAddress(spec.TransmitterID.String), common.HexToAddress(spec.ContractID)) } else { - effectiveTransmitterID, err = chain.TxManager().GetForwarderForEOA(common.HexToAddress(spec.TransmitterID.String)) + effectiveTransmitterID, err = chain.TxManager().GetForwarderForEOA(ctx, common.HexToAddress(spec.TransmitterID.String)) } if err == nil { diff --git a/core/services/ocr2/delegate_test.go b/core/services/ocr2/delegate_test.go index 8f204f57091..1e4be66c7d1 100644 --- a/core/services/ocr2/delegate_test.go +++ b/core/services/ocr2/delegate_test.go @@ -5,10 +5,12 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/pkg/errors" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "gopkg.in/guregu/null.v4" "github.com/smartcontractkit/chainlink-common/pkg/types" + evmcfg "github.com/smartcontractkit/chainlink/v2/core/chains/evm/config/toml" txmmocks "github.com/smartcontractkit/chainlink/v2/core/chains/evm/txmgr/mocks" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils/big" @@ -27,7 +29,6 @@ import ( ) func TestGetEVMEffectiveTransmitterID(t *testing.T) { - ctx := testutils.Context(t) customChainID := big.New(testutils.NewRandomEVMChainID()) config := configtest.NewGeneralConfig(t, func(c *chainlink.Config, s *chainlink.Secrets) { @@ -41,7 +42,7 @@ func TestGetEVMEffectiveTransmitterID(t *testing.T) { }) db := pgtest.NewSqlxDB(t) keyStore := cltest.NewKeyStore(t, db) - require.NoError(t, keyStore.OCR2().Add(ctx, cltest.DefaultOCR2Key)) + require.NoError(t, keyStore.OCR2().Add(testutils.Context(t), cltest.DefaultOCR2Key)) lggr := logger.TestLogger(t) txManager := txmmocks.NewMockEvmTxManager(t) @@ -67,7 +68,7 @@ func TestGetEVMEffectiveTransmitterID(t *testing.T) { jb.OCR2OracleSpec.RelayConfig["sendingKeys"] = tc.sendingKeys jb.ForwardingAllowed = tc.forwardingEnabled - args := []interface{}{tc.getForwarderForEOAArg} + args := []interface{}{mock.Anything, tc.getForwarderForEOAArg} getForwarderMethodName := "GetForwarderForEOA" if tc.pluginType == types.Median { getForwarderMethodName = "GetForwarderForEOAOCR2Feeds" @@ -144,13 +145,14 @@ func TestGetEVMEffectiveTransmitterID(t *testing.T) { } t.Run("when sending keys are not defined, the first one should be set to transmitterID", func(t *testing.T) { + ctx := testutils.Context(t) jb, err := ocr2validate.ValidatedOracleSpecToml(testutils.Context(t), config.OCR2(), config.Insecure(), testspecs.GetOCR2EVMSpecMinimal(), nil) require.NoError(t, err) jb.OCR2OracleSpec.TransmitterID = null.StringFrom("some transmitterID string") jb.OCR2OracleSpec.RelayConfig["sendingKeys"] = nil chain, err := legacyChains.Get(customChainID.String()) require.NoError(t, err) - effectiveTransmitterID, err := ocr2.GetEVMEffectiveTransmitterID(&jb, chain, lggr) + effectiveTransmitterID, err := ocr2.GetEVMEffectiveTransmitterID(ctx, &jb, chain, lggr) require.NoError(t, err) require.Equal(t, "some transmitterID string", effectiveTransmitterID) require.Equal(t, []string{"some transmitterID string"}, jb.OCR2OracleSpec.RelayConfig["sendingKeys"].([]string)) @@ -158,13 +160,14 @@ func TestGetEVMEffectiveTransmitterID(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + ctx := testutils.Context(t) jb, err := ocr2validate.ValidatedOracleSpecToml(testutils.Context(t), config.OCR2(), config.Insecure(), testspecs.GetOCR2EVMSpecMinimal(), nil) require.NoError(t, err) setTestCase(&jb, tc, txManager) chain, err := legacyChains.Get(customChainID.String()) require.NoError(t, err) - effectiveTransmitterID, err := ocr2.GetEVMEffectiveTransmitterID(&jb, chain, lggr) + effectiveTransmitterID, err := ocr2.GetEVMEffectiveTransmitterID(ctx, &jb, chain, lggr) if tc.expectedError { require.Error(t, err) } else { @@ -176,18 +179,18 @@ func TestGetEVMEffectiveTransmitterID(t *testing.T) { if !jb.ForwardingAllowed { require.Equal(t, jb.OCR2OracleSpec.TransmitterID.String, effectiveTransmitterID) } - }) } t.Run("when forwarders are enabled and chain retrieval fails, error should be handled", func(t *testing.T) { + ctx := testutils.Context(t) jb, err := ocr2validate.ValidatedOracleSpecToml(testutils.Context(t), config.OCR2(), config.Insecure(), testspecs.GetOCR2EVMSpecMinimal(), nil) require.NoError(t, err) jb.ForwardingAllowed = true jb.OCR2OracleSpec.TransmitterID = null.StringFrom("0x7e57000000000000000000000000000000000001") chain, err := legacyChains.Get("not an id") require.Error(t, err) - _, err = ocr2.GetEVMEffectiveTransmitterID(&jb, chain, lggr) + _, err = ocr2.GetEVMEffectiveTransmitterID(ctx, &jb, chain, lggr) require.Error(t, err) }) } diff --git a/core/services/ocr2/plugins/ocr2keeper/integration_test.go b/core/services/ocr2/plugins/ocr2keeper/integration_test.go index 1054c59dd1c..c27a1a9dbed 100644 --- a/core/services/ocr2/plugins/ocr2keeper/integration_test.go +++ b/core/services/ocr2/plugins/ocr2keeper/integration_test.go @@ -427,7 +427,7 @@ func setupForwarderForNode( backend *backends.SimulatedBackend, recipient common.Address, linkAddr common.Address) common.Address { - + ctx := testutils.Context(t) faddr, _, authorizedForwarder, err := authorized_forwarder.DeployAuthorizedForwarder(caller, backend, linkAddr, caller.From, recipient, []byte{}) require.NoError(t, err) @@ -444,7 +444,7 @@ func setupForwarderForNode( chain, err := app.GetRelayers().LegacyEVMChains().Get((*big.Int)(&chainID).String()) require.NoError(t, err) - fwdr, err := chain.TxManager().GetForwarderForEOA(recipient) + fwdr, err := chain.TxManager().GetForwarderForEOA(ctx, recipient) require.NoError(t, err) require.Equal(t, faddr, fwdr) diff --git a/core/services/ocrcommon/transmitter.go b/core/services/ocrcommon/transmitter.go index 423db2316a7..f73b6393b9e 100644 --- a/core/services/ocrcommon/transmitter.go +++ b/core/services/ocrcommon/transmitter.go @@ -3,11 +3,13 @@ package ocrcommon import ( "context" "math/big" + "slices" "github.com/ethereum/go-ethereum/common" "github.com/pkg/errors" "github.com/smartcontractkit/chainlink/v2/common/txmgr/types" + "github.com/smartcontractkit/chainlink/v2/core/chains/evm/forwarders" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/txmgr" ) @@ -64,6 +66,51 @@ func NewTransmitter( }, nil } +type txManagerOCR2 interface { + CreateTransaction(ctx context.Context, txRequest txmgr.TxRequest) (tx txmgr.Tx, err error) + GetForwarderForEOAOCR2Feeds(ctx context.Context, eoa, ocr2AggregatorID common.Address) (forwarder common.Address, err error) +} + +type ocr2FeedsTransmitter struct { + ocr2Aggregator common.Address + txManagerOCR2 + transmitter +} + +// NewOCR2FeedsTransmitter creates a new eth transmitter that handles OCR2 Feeds specific logic surrounding forwarders. +// ocr2FeedsTransmitter validates forwarders before every transmission, enabling smooth onchain config changes without job restarts. +func NewOCR2FeedsTransmitter( + txm txManagerOCR2, + fromAddresses []common.Address, + ocr2Aggregator common.Address, + gasLimit uint64, + effectiveTransmitterAddress common.Address, + strategy types.TxStrategy, + checker txmgr.TransmitCheckerSpec, + chainID *big.Int, + keystore roundRobinKeystore, +) (Transmitter, error) { + // Ensure that a keystore is provided. + if keystore == nil { + return nil, errors.New("nil keystore provided to transmitter") + } + + return &ocr2FeedsTransmitter{ + ocr2Aggregator: ocr2Aggregator, + txManagerOCR2: txm, + transmitter: transmitter{ + txm: txm, + fromAddresses: fromAddresses, + gasLimit: gasLimit, + effectiveTransmitterAddress: effectiveTransmitterAddress, + strategy: strategy, + checker: checker, + chainID: chainID, + keystore: keystore, + }, + }, nil +} + func (t *transmitter) CreateEthTransaction(ctx context.Context, toAddress common.Address, payload []byte, txMeta *txmgr.TxMeta) error { roundRobinFromAddress, err := t.keystore.GetRoundRobinAddress(ctx, t.chainID, t.fromAddresses...) @@ -96,3 +143,65 @@ func (t *transmitter) forwarderAddress() common.Address { } return t.effectiveTransmitterAddress } + +func (t *ocr2FeedsTransmitter) CreateEthTransaction(ctx context.Context, toAddress common.Address, payload []byte, txMeta *txmgr.TxMeta) error { + roundRobinFromAddress, err := t.keystore.GetRoundRobinAddress(ctx, t.chainID, t.fromAddresses...) + if err != nil { + return errors.Wrap(err, "skipped OCR transmission, error getting round-robin address") + } + + forwarderAddress, err := t.forwarderAddress(ctx, roundRobinFromAddress, toAddress) + if err != nil { + return err + } + + _, err = t.txm.CreateTransaction(ctx, txmgr.TxRequest{ + FromAddress: roundRobinFromAddress, + ToAddress: toAddress, + EncodedPayload: payload, + FeeLimit: t.gasLimit, + ForwarderAddress: forwarderAddress, + Strategy: t.strategy, + Checker: t.checker, + Meta: txMeta, + }) + + return errors.Wrap(err, "skipped OCR transmission") +} + +// FromAddress for ocr2FeedsTransmitter returns valid forwarder or effectiveTransmitterAddress if forwarders are not set. +func (t *ocr2FeedsTransmitter) FromAddress() common.Address { + roundRobinFromAddress, err := t.keystore.GetRoundRobinAddress(context.Background(), t.chainID, t.fromAddresses...) + if err != nil { + return t.effectiveTransmitterAddress + } + + forwarderAddress, err := t.GetForwarderForEOAOCR2Feeds(context.Background(), roundRobinFromAddress, t.ocr2Aggregator) + if errors.Is(err, forwarders.ErrForwarderForEOANotFound) { + // if there are no valid forwarders try to fallback to eoa + return roundRobinFromAddress + } else if err != nil { + return t.effectiveTransmitterAddress + } + + return forwarderAddress +} + +func (t *ocr2FeedsTransmitter) forwarderAddress(ctx context.Context, eoa, ocr2Aggregator common.Address) (common.Address, error) { + // If effectiveTransmitterAddress is in fromAddresses, then forwarders aren't set. + if slices.Contains(t.fromAddresses, t.effectiveTransmitterAddress) { + return common.Address{}, nil + } + + forwarderAddress, err := t.GetForwarderForEOAOCR2Feeds(ctx, eoa, ocr2Aggregator) + if err != nil { + return common.Address{}, err + } + + // if forwarder address is in fromAddresses, then none of the forwarders are valid + if slices.Contains(t.fromAddresses, forwarderAddress) { + forwarderAddress = common.Address{} + } + + return forwarderAddress, nil +} diff --git a/core/services/pipeline/task.eth_tx.go b/core/services/pipeline/task.eth_tx.go index 354651acbb4..964591cacd2 100644 --- a/core/services/pipeline/task.eth_tx.go +++ b/core/services/pipeline/task.eth_tx.go @@ -140,7 +140,7 @@ func (t *ETHTxTask) Run(ctx context.Context, lggr logger.Logger, vars Vars, inpu var forwarderAddress common.Address if t.forwardingAllowed { var fwderr error - forwarderAddress, fwderr = chain.TxManager().GetForwarderForEOA(fromAddr) + forwarderAddress, fwderr = chain.TxManager().GetForwarderForEOA(ctx, fromAddr) if fwderr != nil { lggr.Warnw("Skipping forwarding for job, will fallback to default behavior", "err", fwderr) } diff --git a/core/services/relay/evm/evm.go b/core/services/relay/evm/evm.go index 585d20df3ab..62a94ae8f3d 100644 --- a/core/services/relay/evm/evm.go +++ b/core/services/relay/evm/evm.go @@ -526,17 +526,34 @@ func newOnChainContractTransmitter(ctx context.Context, lggr logger.Logger, rarg gasLimit = uint64(*opts.pluginGasLimit) } - transmitter, err := ocrcommon.NewTransmitter( - configWatcher.chain.TxManager(), - fromAddresses, - gasLimit, - effectiveTransmitterAddress, - strategy, - checker, - configWatcher.chain.ID(), - ethKeystore, - ) + var transmitter Transmitter + var err error + switch commontypes.OCR2PluginType(rargs.ProviderType) { + case commontypes.Median: + transmitter, err = ocrcommon.NewOCR2FeedsTransmitter( + configWatcher.chain.TxManager(), + fromAddresses, + common.HexToAddress(rargs.ContractID), + gasLimit, + effectiveTransmitterAddress, + strategy, + checker, + configWatcher.chain.ID(), + ethKeystore, + ) + default: + transmitter, err = ocrcommon.NewTransmitter( + configWatcher.chain.TxManager(), + fromAddresses, + gasLimit, + effectiveTransmitterAddress, + strategy, + checker, + configWatcher.chain.ID(), + ethKeystore, + ) + } if err != nil { return nil, pkgerrors.Wrap(err, "failed to create transmitter") }