diff --git a/core/internal/features/ocr2/features_ocr2_helper.go b/core/internal/features/ocr2/features_ocr2_helper.go index 9287d0df5b1..14b465ec6be 100644 --- a/core/internal/features/ocr2/features_ocr2_helper.go +++ b/core/internal/features/ocr2/features_ocr2_helper.go @@ -610,7 +610,7 @@ updateInterval = "1m" contractABI, err2 := abi.JSON(strings.NewReader(ocr2aggregator.OCR2AggregatorABI)) require.NoError(t, err2) apps[0].GetRelayers().LegacyEVMChains().Slice() - ct, err2 := evm.NewOCRContractTransmitter(testutils.Context(t), ocrContractAddress, apps[0].GetRelayers().LegacyEVMChains().Slice()[0].Client(), contractABI, nil, apps[0].GetRelayers().LegacyEVMChains().Slice()[0].LogPoller(), lggr) + ct, err2 := evm.NewOCRContractTransmitter(testutils.Context(t), ocrContractAddress, apps[0].GetRelayers().LegacyEVMChains().Slice()[0].Client(), contractABI, nil, apps[0].GetRelayers().LegacyEVMChains().Slice()[0].LogPoller(), lggr, apps[0].KeyStore.Eth()) require.NoError(t, err2) configDigest, epoch, err2 := ct.LatestConfigDigestAndEpoch(testutils.Context(t)) require.NoError(t, err2) diff --git a/core/internal/features/ocr2/features_ocr2_test.go b/core/internal/features/ocr2/features_ocr2_test.go index 01c269d19e3..6578a4a9aff 100644 --- a/core/internal/features/ocr2/features_ocr2_test.go +++ b/core/internal/features/ocr2/features_ocr2_test.go @@ -264,7 +264,7 @@ updateInterval = "1m" // Assert we can read the latest config digest and epoch after a report has been submitted. contractABI, err := abi.JSON(strings.NewReader(ocr2aggregator.OCR2AggregatorABI)) require.NoError(t, err) - ct, err := evm.NewOCRContractTransmitter(testutils.Context(t), ocrContractAddress, apps[0].GetRelayers().LegacyEVMChains().Slice()[0].Client(), contractABI, nil, apps[0].GetRelayers().LegacyEVMChains().Slice()[0].LogPoller(), lggr) + ct, err := evm.NewOCRContractTransmitter(testutils.Context(t), ocrContractAddress, apps[0].GetRelayers().LegacyEVMChains().Slice()[0].Client(), contractABI, nil, apps[0].GetRelayers().LegacyEVMChains().Slice()[0].LogPoller(), lggr, apps[0].KeyStore.Eth()) require.NoError(t, err) configDigest, epoch, err := ct.LatestConfigDigestAndEpoch(testutils.Context(t)) require.NoError(t, err) diff --git a/core/services/job/job_orm_test.go b/core/services/job/job_orm_test.go index 6e7a0b2a034..b05c17ac059 100644 --- a/core/services/job/job_orm_test.go +++ b/core/services/job/job_orm_test.go @@ -19,6 +19,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/types" "github.com/smartcontractkit/chainlink-common/pkg/utils/jsonserializable" pkgworkflows "github.com/smartcontractkit/chainlink-common/pkg/workflows" + "github.com/smartcontractkit/chainlink/v2/core/services/keystore" "github.com/smartcontractkit/chainlink/v2/core/bridges" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/assets" @@ -2233,3 +2234,108 @@ func TestORM_CreateJob_OCR2_With_DualTransmission(t *testing.T) { keyStore.Eth().XXXTestingOnlyAdd(ctx, dtTransmitterAddress) require.NoError(t, jobORM.CreateJob(ctx, &jb)) } + +func TestORM_CreateJob_KeyLocking(t *testing.T) { + ctx := testutils.Context(t) + customChainID := big.New(testutils.NewRandomEVMChainID()) + + config := configtest.NewGeneralConfig(t, func(c *chainlink.Config, s *chainlink.Secrets) { + enabled := true + c.EVM = append(c.EVM, &evmcfg.EVMConfig{ + ChainID: customChainID, + Chain: evmcfg.Defaults(customChainID), + Enabled: &enabled, + Nodes: evmcfg.EVMNodes{{}}, + }) + }) + db := pgtest.NewSqlxDB(t) + ks := cltest.NewKeyStore(t, db) + require.NoError(t, ks.OCR2().Add(ctx, cltest.DefaultOCR2Key)) + _, transmitterID := cltest.MustInsertRandomKey(t, ks.Eth()) + dtTransmitterAddress := cltest.MustGenerateRandomKey(t) + ks.Eth().XXXTestingOnlyAdd(ctx, dtTransmitterAddress) + + baseJobSpec := fmt.Sprintf(testspecs.OCR2EVMDualTransmissionSpecMinimalTemplate, transmitterID.String()) + + lggr := logger.TestLogger(t) + pipelineORM := pipeline.NewORM(db, lggr, config.JobPipeline().MaxSuccessfulRuns()) + bridgesORM := bridges.NewORM(db) + + jobORM := NewTestORM(t, db, pipelineORM, bridgesORM, ks) + + t.Run("keys not locked", func(t *testing.T) { + completeDualTransmissionSpec := fmt.Sprintf(` + enableDualTransmission=true + [relayConfig.dualTransmission] + contractAddress = '0x613a38AC1659769640aaE063C651F48E0250454C' + transmitterAddress = '%s' + [relayConfig.dualTransmission.meta] + key1 = ['val1'] + key2 = ['val2','val3'] + `, + dtTransmitterAddress.Address.String()) + + jb, err := ocr2validate.ValidatedOracleSpecToml(testutils.Context(t), config.OCR2(), config.Insecure(), baseJobSpec+completeDualTransmissionSpec, nil) + require.NoError(t, err) + + jb.OCR2OracleSpec.TransmitterID = null.StringFrom(transmitterID.String()) + jb.Name = null.StringFrom(uuid.NewString()) + + require.NoError(t, jobORM.CreateJob(ctx, &jb)) + }) + + t.Run("keys locked", func(t *testing.T) { + completeDualTransmissionSpec := fmt.Sprintf(` + enableDualTransmission=true + [relayConfig.dualTransmission] + contractAddress = '0x613a38AC1659769640aaE063C651F48E0250454C' + transmitterAddress = '%s' + [relayConfig.dualTransmission.meta] + key1 = ['val1'] + key2 = ['val2','val3'] + `, + dtTransmitterAddress.Address.String()) + + rm, err := ks.Eth().GetResourceMutex(ctx, transmitterID) + require.NoError(t, err) + require.NoError(t, rm.TryLock(keystore.TXMv1)) + rm, err = ks.Eth().GetResourceMutex(ctx, dtTransmitterAddress.Address) + require.NoError(t, err) + require.NoError(t, rm.TryLock(keystore.TXMv2)) + jb, err := ocr2validate.ValidatedOracleSpecToml(testutils.Context(t), config.OCR2(), config.Insecure(), baseJobSpec+completeDualTransmissionSpec, nil) + require.NoError(t, err) + + jb.OCR2OracleSpec.TransmitterID = null.StringFrom(transmitterID.String()) + jb.Name = null.StringFrom(uuid.NewString()) + + require.NoError(t, jobORM.CreateJob(ctx, &jb)) + }) + + t.Run("keys locked but job spec misconfigured", func(t *testing.T) { + rm, err := ks.Eth().GetResourceMutex(ctx, transmitterID) + require.NoError(t, err) + require.NoError(t, rm.TryLock(keystore.TXMv1)) + rm, err = ks.Eth().GetResourceMutex(ctx, dtTransmitterAddress.Address) + require.NoError(t, err) + require.NoError(t, rm.TryLock(keystore.TXMv2)) + + completeDualTransmissionSpec := fmt.Sprintf(` + enableDualTransmission=true + [relayConfig.dualTransmission] + contractAddress = '0x613a38AC1659769640aaE063C651F48E0250454C' + transmitterAddress = '%s' + [relayConfig.dualTransmission.meta] + key1 = ['val1'] + key2 = ['val2','val3'] + `, + transmitterID.String()) + + jb, err := ocr2validate.ValidatedOracleSpecToml(testutils.Context(t), config.OCR2(), config.Insecure(), baseJobSpec+completeDualTransmissionSpec, nil) + require.NoError(t, err) + + jb.OCR2OracleSpec.TransmitterID = null.StringFrom(dtTransmitterAddress.Address.String()) + jb.Name = null.StringFrom(uuid.NewString()) + + require.ErrorContains(t, jobORM.CreateJob(ctx, &jb), "cannot be a secondary transmitter address because it's used a primary transmitter in another job") + }) +} diff --git a/core/services/job/orm.go b/core/services/job/orm.go index 62cc2cd596a..0d944dcfdb7 100644 --- a/core/services/job/orm.go +++ b/core/services/job/orm.go @@ -308,6 +308,10 @@ func (o *orm) CreateJob(ctx context.Context, jb *Job) error { } if enableDualTransmission, ok := jb.OCR2OracleSpec.RelayConfig["enableDualTransmission"]; ok && enableDualTransmission != nil { + if jb.OCR2OracleSpec.Relay != relay.NetworkEVM { + return errors.New("dual transmission is enabled only for EVM") + } + rawDualTransmissionConfig, ok := jb.OCR2OracleSpec.RelayConfig["dualTransmission"] if !ok { return errors.New("dual transmission is enabled but no dual transmission config present") @@ -341,6 +345,23 @@ func (o *orm) CreateJob(ctx context.Context, jb *Job) error { return errors.Wrap(err, "unknown dual transmission transmitterAddress") } + // Check if secondary transmitter address is used as primary somewhere else + hasLock, err2 := checkIfKeyHasLock(ctx, tx.keyStore.Eth(), common.HexToAddress(dtTransmitterAddress), keystore.TXMv1) + if err2 != nil { + return err2 + } else if hasLock { + return errors.Errorf("key %s cannot be a secondary transmitter address because it's used a primary transmitter in another job", dtTransmitterAddress) + } + } + + // Check if primary transmitter address is used as secondary somewhere else, don't check for mercury as it uses CSA keys for transmitters + if jb.OCR2OracleSpec.PluginType != types.Mercury { + hasLock, err2 := checkIfKeyHasLock(ctx, tx.keyStore.Eth(), common.HexToAddress(jb.OCR2OracleSpec.TransmitterID.String), keystore.TXMv2) + if err2 != nil { + return err2 + } else if hasLock { + return errors.Errorf("key %s cannot be a (primary) transmitter address because it's used a secondary transmitter address in another job", jb.OCR2OracleSpec.TransmitterID.String) + } } specID, err := tx.insertOCR2OracleSpec(ctx, jb.OCR2OracleSpec) @@ -1745,3 +1766,12 @@ func validateDualTransmissionMeta(meta map[string]interface{}) error { return nil } + +func checkIfKeyHasLock(ctx context.Context, ks keystore.Eth, address common.Address, usage keystore.ServiceType) (bool, error) { + rm, err := ks.GetResourceMutex(ctx, address) + if err != nil { + return false, err + } + + return rm.IsLocked(usage) +} diff --git a/core/services/keystore/eth.go b/core/services/keystore/eth.go index f69bbec28d2..8d5e5bf78be 100644 --- a/core/services/keystore/eth.go +++ b/core/services/keystore/eth.go @@ -48,6 +48,7 @@ type Eth interface { GetStateForKey(ctx context.Context, key ethkey.KeyV2) (ethkey.State, error) GetStatesForChain(ctx context.Context, chainID *big.Int) ([]ethkey.State, error) EnabledAddressesForChain(ctx context.Context, chainID *big.Int) (addresses []common.Address, err error) + GetResourceMutex(ctx context.Context, address common.Address) (*ResourceMutex, error) XXXTestingOnlySetState(ctx context.Context, keyState ethkey.State) XXXTestingOnlyAdd(ctx context.Context, key ethkey.KeyV2) @@ -59,6 +60,26 @@ type eth struct { ds sqlutil.DataSource subscribers [](chan struct{}) subscribersMu *sync.RWMutex + resourceMutex map[common.Address]*ResourceMutex // ResourceMutex is an internal field and ought not be persisted to the database. Its main usage is to verify that the same key is not used for both TXMv1 and TXMv2 (usage in both TXMs will cause nonce drift and will lead to missing transactions). This functionality should be removed after we completely switch to TXMv2 +} + +// GetResourceMutex gets the resource mutex associates with the address if no resource mutex is found a new one is created +func (ks *eth) GetResourceMutex(ctx context.Context, address common.Address) (*ResourceMutex, error) { + ks.lock.Lock() + defer ks.lock.Unlock() + if ks.isLocked() { + return nil, ErrLocked + } + + if ks.resourceMutex == nil { + ks.resourceMutex = make(map[common.Address]*ResourceMutex) + } + + _, exists := ks.resourceMutex[address] + if !exists { + ks.resourceMutex[address] = NewResourceMutex() + } + return ks.resourceMutex[address], nil } var _ Eth = ð{} diff --git a/core/services/keystore/mocks/eth.go b/core/services/keystore/mocks/eth.go index 7ed960663f7..44892fa3059 100644 --- a/core/services/keystore/mocks/eth.go +++ b/core/services/keystore/mocks/eth.go @@ -10,6 +10,8 @@ import ( ethkey "github.com/smartcontractkit/chainlink/v2/core/services/keystore/keys/ethkey" + keystore "github.com/smartcontractkit/chainlink/v2/core/services/keystore" + mock "github.com/stretchr/testify/mock" types "github.com/ethereum/go-ethereum/core/types" @@ -702,6 +704,65 @@ func (_c *Eth_GetAll_Call) RunAndReturn(run func(context.Context) ([]ethkey.KeyV return _c } +// GetResourceMutex provides a mock function with given fields: ctx, address +func (_m *Eth) GetResourceMutex(ctx context.Context, address common.Address) (*keystore.ResourceMutex, error) { + ret := _m.Called(ctx, address) + + if len(ret) == 0 { + panic("no return value specified for GetResourceMutex") + } + + var r0 *keystore.ResourceMutex + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, common.Address) (*keystore.ResourceMutex, error)); ok { + return rf(ctx, address) + } + if rf, ok := ret.Get(0).(func(context.Context, common.Address) *keystore.ResourceMutex); ok { + r0 = rf(ctx, address) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*keystore.ResourceMutex) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, common.Address) error); ok { + r1 = rf(ctx, address) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Eth_GetResourceMutex_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetResourceMutex' +type Eth_GetResourceMutex_Call struct { + *mock.Call +} + +// GetResourceMutex is a helper method to define mock.On call +// - ctx context.Context +// - address common.Address +func (_e *Eth_Expecter) GetResourceMutex(ctx interface{}, address interface{}) *Eth_GetResourceMutex_Call { + return &Eth_GetResourceMutex_Call{Call: _e.mock.On("GetResourceMutex", ctx, address)} +} + +func (_c *Eth_GetResourceMutex_Call) Run(run func(ctx context.Context, address common.Address)) *Eth_GetResourceMutex_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(common.Address)) + }) + return _c +} + +func (_c *Eth_GetResourceMutex_Call) Return(_a0 *keystore.ResourceMutex, _a1 error) *Eth_GetResourceMutex_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Eth_GetResourceMutex_Call) RunAndReturn(run func(context.Context, common.Address) (*keystore.ResourceMutex, error)) *Eth_GetResourceMutex_Call { + _c.Call.Return(run) + return _c +} + // GetRoundRobinAddress provides a mock function with given fields: ctx, chainID, addresses func (_m *Eth) GetRoundRobinAddress(ctx context.Context, chainID *big.Int, addresses ...common.Address) (common.Address, error) { _va := make([]interface{}, len(addresses)) diff --git a/core/services/keystore/models.go b/core/services/keystore/models.go index 1ebc7480997..55a1c847816 100644 --- a/core/services/keystore/models.go +++ b/core/services/keystore/models.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "math/big" + "sync" "time" gethkeystore "github.com/ethereum/go-ethereum/accounts/keystore" @@ -423,3 +424,70 @@ func (rawKeys rawKeyRing) keys() (*keyRing, error) { func adulteratedPassword(password string) string { return "master-password-" + password } + +type ResourceMutex struct { + mu sync.Mutex + serviceType ServiceType + count int // Tracks active users per service type +} +type ServiceType int + +const ( + TXMv1 ServiceType = iota + TXMv2 +) + +// TryLock attempts to lock the resource for the specified service type. +// It returns an error if the resource is locked by a different service type. +func (rm *ResourceMutex) TryLock(serviceType ServiceType) error { + rm.mu.Lock() + defer rm.mu.Unlock() + + if rm.count == 0 { + rm.serviceType = serviceType + } + + // Check if other service types are using the resource + if rm.serviceType != serviceType && rm.count > 0 { + return errors.New("resource is locked by another service type") + } + + // Increment active count for the current service type + rm.count++ + return nil +} + +// Unlock releases the lock for the service type +func (rm *ResourceMutex) Unlock(serviceType ServiceType) error { + rm.mu.Lock() + defer rm.mu.Unlock() + + // Check if the service type has an active lock + if rm.count == 0 { + return errors.New("no active lock") + } + + if rm.serviceType != serviceType { + return errors.New("no active lock for this service type") + } + + // Decrement active count for the service type + rm.count-- + return nil +} + +// IsLocked checks if the resource is locked by a specific service type. +func (rm *ResourceMutex) IsLocked(serviceType ServiceType) (bool, error) { + rm.mu.Lock() + defer rm.mu.Unlock() + + if rm.count == 0 || rm.serviceType != serviceType { + return false, nil + } + + return true, nil +} + +func NewResourceMutex() *ResourceMutex { + return &ResourceMutex{} +} diff --git a/core/services/keystore/models_test.go b/core/services/keystore/models_test.go index a66e29865d1..e4a2058b569 100644 --- a/core/services/keystore/models_test.go +++ b/core/services/keystore/models_test.go @@ -161,3 +161,77 @@ func TestKeyRing_Encrypt_Decrypt(t *testing.T) { require.Error(t, err) }) } + +func TestResourceMutex_LockUnlock(t *testing.T) { + rm := &ResourceMutex{} + + err := rm.TryLock(TXMv1) + require.NoError(t, err) + + err = rm.Unlock(TXMv1) + require.NoError(t, err) +} + +func TestResourceMutex_LockByDifferentServiceType(t *testing.T) { + rm := &ResourceMutex{} + + err := rm.TryLock(TXMv1) + require.NoError(t, err) + + err = rm.TryLock(TXMv2) + require.Error(t, err) + require.Equal(t, "resource is locked by another service type", err.Error()) +} + +func TestResourceMutex_UnlockWithoutLock(t *testing.T) { + rm := &ResourceMutex{} + + err := rm.Unlock(TXMv1) + require.Error(t, err) + require.Equal(t, "no active lock", err.Error()) + + require.NoError(t, rm.TryLock(TXMv1)) + err = rm.Unlock(TXMv2) + require.Error(t, err) + require.Equal(t, "no active lock for this service type", err.Error()) +} + +func TestResourceMutex_MultipleLocks(t *testing.T) { + rm := &ResourceMutex{} + + err := rm.TryLock(TXMv1) + require.NoError(t, err) + + err = rm.TryLock(TXMv1) + require.NoError(t, err) + + err = rm.Unlock(TXMv1) + require.NoError(t, err) + + err = rm.Unlock(TXMv1) + require.NoError(t, err) +} + +func TestIsLocked_WhenResourceIsLockedByServiceType(t *testing.T) { + rm := &ResourceMutex{serviceType: TXMv1, count: 1} + + locked, err := rm.IsLocked(TXMv1) + require.NoError(t, err) + require.True(t, locked) +} + +func TestIsLocked_WhenResourceIsNotLockedByServiceType(t *testing.T) { + rm := &ResourceMutex{} + + locked, err := rm.IsLocked(TXMv1) + require.NoError(t, err) + require.False(t, locked) +} + +func TestIsLocked_WhenResourceIsLockedByDifferentServiceType(t *testing.T) { + rm := &ResourceMutex{serviceType: TXMv2, count: 1} + + locked, err := rm.IsLocked(TXMv1) + require.NoError(t, err) + require.False(t, locked) +} diff --git a/core/services/ocr2/plugins/ccip/transmitter/transmitter.go b/core/services/ocr2/plugins/ccip/transmitter/transmitter.go index abb023a4251..118788cd95f 100644 --- a/core/services/ocr2/plugins/ccip/transmitter/transmitter.go +++ b/core/services/ocr2/plugins/ccip/transmitter/transmitter.go @@ -28,6 +28,7 @@ type Transmitter interface { FromAddress(context.Context) common.Address CreateSecondaryEthTransaction(context.Context, []byte, *txmgr.TxMeta) error + SecondaryFromAddress(context.Context) (common.Address, error) } type transmitter struct { @@ -147,3 +148,7 @@ func (t *transmitter) forwarderAddress() common.Address { func (t *transmitter) CreateSecondaryEthTransaction(ctx context.Context, bytes []byte, meta *txmgr.TxMeta) error { return errors.New("trying to send a secondary transmission on a non dual transmitter") } + +func (t *transmitter) SecondaryFromAddress(ctx context.Context) (common.Address, error) { + return common.Address{}, errors.New("trying to get secondary address on a non dual transmitter") +} diff --git a/core/services/ocrcommon/dual_transmitter.go b/core/services/ocrcommon/dual_transmitter.go index efc60978f19..b07ca0a9b9b 100644 --- a/core/services/ocrcommon/dual_transmitter.go +++ b/core/services/ocrcommon/dual_transmitter.go @@ -123,6 +123,10 @@ func (t *ocr2FeedsDualTransmission) FromAddress(ctx context.Context) common.Addr return forwarderAddress } +func (t *ocr2FeedsDualTransmission) SecondaryFromAddress(ctx context.Context) (common.Address, error) { + return t.secondaryFromAddress, nil +} + func (t *ocr2FeedsDualTransmission) urlParams() string { values := url.Values{} for k, v := range t.secondaryMeta { diff --git a/core/services/ocrcommon/transmitter.go b/core/services/ocrcommon/transmitter.go index 01200bbb7cb..d38b765a792 100644 --- a/core/services/ocrcommon/transmitter.go +++ b/core/services/ocrcommon/transmitter.go @@ -11,6 +11,7 @@ import ( "github.com/smartcontractkit/chainlink/v2/common/txmgr/types" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/forwarders" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/txmgr" + "github.com/smartcontractkit/chainlink/v2/core/services/keystore" types2 "github.com/smartcontractkit/chainlink/v2/core/services/relay/evm/types" ) @@ -27,6 +28,7 @@ type Transmitter interface { FromAddress(context.Context) common.Address CreateSecondaryEthTransaction(context.Context, []byte, *txmgr.TxMeta) error + SecondaryFromAddress(context.Context) (common.Address, error) } type transmitter struct { @@ -82,6 +84,7 @@ type ocr2FeedsTransmitter struct { // NewOCR2FeedsTransmitter creates a new eth transmitter that handles OCR2 Feeds specific logic surrounding forwarders. // ocr2FeedsTransmitter validates forwarders before every transmission, enabling smooth onchain config changes without job restarts. func NewOCR2FeedsTransmitter( + ctx context.Context, txm txManagerOCR2, fromAddresses []common.Address, ocr2Aggregator common.Address, @@ -90,15 +93,26 @@ func NewOCR2FeedsTransmitter( strategy types.TxStrategy, checker txmgr.TransmitCheckerSpec, chainID *big.Int, - keystore roundRobinKeystore, + ks keystore.Eth, dualTransmissionConfig *types2.DualTransmissionConfig, ) (Transmitter, error) { // Ensure that a keystore is provided. - if keystore == nil { + if ks == nil { return nil, errors.New("nil keystore provided to transmitter") } + if hasLock, err := keyHasLock(ctx, ks, effectiveTransmitterAddress, keystore.TXMv2); err != nil { + return nil, err + } else if hasLock { + return nil, errors.Errorf("key %s is used as a secondary transmitter in another job. primary and secondary transmitters cannot be mixed", effectiveTransmitterAddress.String()) + } + if dualTransmissionConfig != nil { + if hasLock, err := keyHasLock(ctx, ks, dualTransmissionConfig.TransmitterAddress, keystore.TXMv1); err != nil { + return nil, err + } else if hasLock { + return nil, errors.Errorf("key %s is used as a primary transmitter in another job. primary and secondary transmitters cannot be mixed", effectiveTransmitterAddress.String()) + } return &ocr2FeedsDualTransmission{ ocr2Aggregator: ocr2Aggregator, txm: txm, @@ -109,7 +123,7 @@ func NewOCR2FeedsTransmitter( strategy: strategy, checker: checker, chainID: chainID, - keystore: keystore, + keystore: ks, secondaryContractAddress: dualTransmissionConfig.ContractAddress, secondaryFromAddress: dualTransmissionConfig.TransmitterAddress, secondaryMeta: dualTransmissionConfig.Meta, @@ -126,7 +140,7 @@ func NewOCR2FeedsTransmitter( strategy: strategy, checker: checker, chainID: chainID, - keystore: keystore, + keystore: ks, }, }, nil } @@ -153,6 +167,9 @@ func (t *transmitter) CreateEthTransaction(ctx context.Context, toAddress common func (t *transmitter) CreateSecondaryEthTransaction(ctx context.Context, bytes []byte, meta *txmgr.TxMeta) error { return errors.New("trying to send a secondary transmission on a non dual transmitter") } +func (t *transmitter) SecondaryFromAddress(ctx context.Context) (common.Address, error) { + return common.Address{}, errors.New("trying to get secondary address on a non dual transmitter") +} func (t *transmitter) FromAddress(context.Context) common.Address { return t.effectiveTransmitterAddress @@ -232,3 +249,12 @@ func (t *ocr2FeedsTransmitter) forwarderAddress(ctx context.Context, eoa, ocr2Ag func (t *ocr2FeedsTransmitter) CreateSecondaryEthTransaction(ctx context.Context, bytes []byte, meta *txmgr.TxMeta) error { return errors.New("trying to send a secondary transmission on a non dual transmitter") } + +func keyHasLock(ctx context.Context, ks keystore.Eth, address common.Address, service keystore.ServiceType) (bool, error) { + rm, err := ks.GetResourceMutex(ctx, address) + if err != nil { + return false, err + } + + return rm.IsLocked(service) +} diff --git a/core/services/ocrcommon/transmitter_test.go b/core/services/ocrcommon/transmitter_test.go index bb91a87d517..95787cee8ad 100644 --- a/core/services/ocrcommon/transmitter_test.go +++ b/core/services/ocrcommon/transmitter_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" commontxmmocks "github.com/smartcontractkit/chainlink/v2/common/txmgr/types/mocks" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/txmgr" txmmocks "github.com/smartcontractkit/chainlink/v2/core/chains/evm/txmgr/mocks" @@ -178,6 +179,7 @@ func Test_DualTransmitter(t *testing.T) { db := pgtest.NewSqlxDB(t) ethKeyStore := cltest.NewKeyStore(t, db).Eth() + ctx := tests.Context(t) _, fromAddress := cltest.MustInsertRandomKey(t, ethKeyStore) _, secondaryFromAddress := cltest.MustInsertRandomKey(t, ethKeyStore) @@ -203,6 +205,7 @@ func Test_DualTransmitter(t *testing.T) { } transmitter, err := ocrcommon.NewOCR2FeedsTransmitter( + ctx, txm, []common.Address{fromAddress}, contractAddress, diff --git a/core/services/relay/evm/contract_transmitter.go b/core/services/relay/evm/contract_transmitter.go index 248968ec053..8aae2edf2b5 100644 --- a/core/services/relay/evm/contract_transmitter.go +++ b/core/services/relay/evm/contract_transmitter.go @@ -17,11 +17,11 @@ import ( ocrtypes "github.com/smartcontractkit/libocr/offchainreporting2plus/types" "github.com/smartcontractkit/chainlink-common/pkg/logger" - "github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/txmgr" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/utils" "github.com/smartcontractkit/chainlink/v2/core/services" + "github.com/smartcontractkit/chainlink/v2/core/services/keystore" ) type ContractTransmitter interface { @@ -35,7 +35,9 @@ type Transmitter interface { CreateEthTransaction(ctx context.Context, toAddress gethcommon.Address, payload []byte, txMeta *txmgr.TxMeta) error FromAddress(context.Context) gethcommon.Address + // Dual transmission CreateSecondaryEthTransaction(ctx context.Context, payload []byte, txMeta *txmgr.TxMeta) error + SecondaryFromAddress(context.Context) (gethcommon.Address, error) } type ReportToEthMetadata func([]byte) (*txmgr.TxMeta, error) @@ -87,6 +89,7 @@ type contractTransmitter struct { contractReader contractReader lp logpoller.LogPoller lggr logger.Logger + ks keystore.Eth // Options transmitterOptions *transmitterOps } @@ -103,6 +106,7 @@ func NewOCRContractTransmitter( transmitter Transmitter, lp logpoller.LogPoller, lggr logger.Logger, + ks keystore.Eth, opts ...OCRTransmitterOption, ) (*contractTransmitter, error) { transmitted, ok := contractABI.Events["Transmitted"] @@ -118,6 +122,7 @@ func NewOCRContractTransmitter( lp: lp, contractReader: caller, lggr: logger.Named(lggr, "OCRContractTransmitter"), + ks: ks, transmitterOptions: &transmitterOps{ reportToEvmTxMeta: reportToEvmTxMetaNoop, excludeSigs: false, @@ -246,8 +251,22 @@ func (oc *contractTransmitter) FromAccount(ctx context.Context) (ocrtypes.Accoun return ocrtypes.Account(oc.transmitter.FromAddress(ctx).String()), nil } -func (oc *contractTransmitter) Start(ctx context.Context) error { return nil } -func (oc *contractTransmitter) Close() error { return nil } +func (oc *contractTransmitter) Start(ctx context.Context) error { + // Lock the transmitters to TXMv1 + rm, err := oc.ks.GetResourceMutex(ctx, oc.transmitter.FromAddress(ctx)) + if err != nil { + return err + } + return rm.TryLock(keystore.TXMv1) +} +func (oc *contractTransmitter) Close() error { + // Unlock the transmitters to TXMv1 + rm, err := oc.ks.GetResourceMutex(context.Background(), oc.transmitter.FromAddress(context.Background())) + if err != nil { + return err + } + return rm.Unlock(keystore.TXMv1) +} // Has no state/lifecycle so it's always healthy and ready func (oc *contractTransmitter) Ready() error { return nil } diff --git a/core/services/relay/evm/contract_transmitter_test.go b/core/services/relay/evm/contract_transmitter_test.go index 6106389f326..3f00f647bc0 100644 --- a/core/services/relay/evm/contract_transmitter_test.go +++ b/core/services/relay/evm/contract_transmitter_test.go @@ -19,6 +19,7 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/chains/evm/txmgr" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/logger" + "github.com/smartcontractkit/chainlink/v2/core/services/keystore/mocks" "github.com/smartcontractkit/libocr/commontypes" "github.com/smartcontractkit/libocr/gethwrappers2/ocr2aggregator" @@ -34,6 +35,10 @@ type mockTransmitter struct { lastPayload []byte } +func (m *mockTransmitter) SecondaryFromAddress(ctx context.Context) (gethcommon.Address, error) { + return gethcommon.Address{}, nil +} + func (m *mockTransmitter) CreateSecondaryEthTransaction(ctx context.Context, bytes []byte, meta *txmgr.TxMeta) error { return nil } @@ -52,6 +57,7 @@ func TestContractTransmitter(t *testing.T) { c := evmclimocks.NewClient(t) lp := lpmocks.NewLogPoller(t) ctx := testutils.Context(t) + ks := mocks.NewEth(t) // scanLogs = false digestAndEpochDontScanLogs, _ := hex.DecodeString( "0000000000000000000000000000000000000000000000000000000000000000" + // false @@ -63,7 +69,7 @@ func TestContractTransmitter(t *testing.T) { reportToEvmTxMeta := func(b []byte) (*txmgr.TxMeta, error) { return &txmgr.TxMeta{}, nil } - ot, err := NewOCRContractTransmitter(ctx, gethcommon.Address{}, c, contractABI, &mockTransmitter{}, lp, lggr, + ot, err := NewOCRContractTransmitter(ctx, gethcommon.Address{}, c, contractABI, &mockTransmitter{}, lp, lggr, ks, WithReportToEthMetadata(reportToEvmTxMeta)) require.NoError(t, err) digest, epoch, err := ot.LatestConfigDigestAndEpoch(testutils.Context(t)) @@ -157,6 +163,7 @@ func oneSignature() []ocrtypes.AttributedOnchainSignature { } func createContractTransmitter(ctx context.Context, t *testing.T, transmitter Transmitter, ops ...OCRTransmitterOption) *contractTransmitter { + ethKeyStore := mocks.NewEth(t) contractABI, err := abi.JSON(strings.NewReader(ocr2aggregator.OCR2AggregatorMetaData.ABI)) require.NoError(t, err) lp := lpmocks.NewLogPoller(t) @@ -169,6 +176,7 @@ func createContractTransmitter(ctx context.Context, t *testing.T, transmitter Tr transmitter, lp, logger.TestLogger(t), + ethKeyStore, ops..., ) require.NoError(t, err) diff --git a/core/services/relay/evm/dual_contract_transmitter.go b/core/services/relay/evm/dual_contract_transmitter.go index 86d7d38be2e..981c6120cdd 100644 --- a/core/services/relay/evm/dual_contract_transmitter.go +++ b/core/services/relay/evm/dual_contract_transmitter.go @@ -13,13 +13,14 @@ import ( "github.com/ethereum/go-ethereum/common" gethcommon "github.com/ethereum/go-ethereum/common" "github.com/pkg/errors" + "go.uber.org/multierr" "github.com/smartcontractkit/libocr/offchainreporting2plus/chains/evmutil" ocrtypes "github.com/smartcontractkit/libocr/offchainreporting2plus/types" "github.com/smartcontractkit/chainlink-common/pkg/logger" - "github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller" + "github.com/smartcontractkit/chainlink/v2/core/services/keystore" ) // TODO: Remove when new dual transmitter contracts are merged @@ -36,6 +37,7 @@ type dualContractTransmitter struct { contractReader contractReader lp logpoller.LogPoller lggr logger.Logger + ks keystore.Eth // Options transmitterOptions *transmitterOps } @@ -56,6 +58,7 @@ func NewOCRDualContractTransmitter( transmitter Transmitter, lp logpoller.LogPoller, lggr logger.Logger, + ethKeystore keystore.Eth, opts ...OCRTransmitterOption, ) (*dualContractTransmitter, error) { transmitted, ok := contractABI.Events["Transmitted"] @@ -71,7 +74,8 @@ func NewOCRDualContractTransmitter( transmittedEventSig: transmitted.ID, lp: lp, contractReader: caller, - lggr: logger.Named(lggr, "OCRDualContractTransmitter"), + lggr: logger.Named(lggr, "OCR2DualContractTransmitter"), + ks: ethKeystore, transmitterOptions: &transmitterOps{ reportToEvmTxMeta: reportToEvmTxMetaNoop, excludeSigs: false, @@ -173,8 +177,88 @@ func (oc *dualContractTransmitter) FromAccount(ctx context.Context) (ocrtypes.Ac return ocrtypes.Account(oc.transmitter.FromAddress(ctx).String()), nil } -func (oc *dualContractTransmitter) Start(ctx context.Context) error { return nil } -func (oc *dualContractTransmitter) Close() error { return nil } +func (oc *dualContractTransmitter) lockTransmitters(ctx context.Context) error { + err := oc.lockPrimary(ctx) + if err != nil { + return err + } + err = oc.lockSecondary(ctx) + if err != nil { + return multierr.Append(err, oc.unlockPrimary(ctx)) + } + return nil +} + +func (oc *dualContractTransmitter) unlockTransmitters(ctx context.Context) error { + return multierr.Append(oc.unlockPrimary(ctx), oc.unlockSecondary(ctx)) +} + +func (oc *dualContractTransmitter) unlockPrimary(ctx context.Context) error { + primaryAddress := oc.transmitter.FromAddress(ctx) + rmPrimary, err := oc.ks.GetResourceMutex(ctx, primaryAddress) + if err != nil { + return err + } + err = rmPrimary.Unlock(keystore.TXMv1) + if err != nil { + return err + } + oc.lggr.Debugf("Key %s has been unlocked for TXMv1", primaryAddress.String()) + return nil +} +func (oc *dualContractTransmitter) unlockSecondary(ctx context.Context) error { + secondaryAddress, err := oc.transmitter.SecondaryFromAddress(ctx) + if err != nil { + return err + } + rmSecondary, err := oc.ks.GetResourceMutex(ctx, secondaryAddress) + if err != nil { + return err + } + err = rmSecondary.Unlock(keystore.TXMv2) + if err != nil { + return err + } + oc.lggr.Debugf("Key %s has been unlocked for TXMv2", secondaryAddress.String()) + return nil +} + +func (oc *dualContractTransmitter) lockPrimary(ctx context.Context) error { + primaryAddress := oc.transmitter.FromAddress(ctx) + rmPrimary, err := oc.ks.GetResourceMutex(ctx, primaryAddress) + if err != nil { + return err + } + err = rmPrimary.TryLock(keystore.TXMv1) + if err != nil { + return err + } + oc.lggr.Debugf("Key %s has been locked for TXMv1", primaryAddress.String()) + return nil +} +func (oc *dualContractTransmitter) lockSecondary(ctx context.Context) error { + secondaryAddress, err := oc.transmitter.SecondaryFromAddress(ctx) + if err != nil { + return err + } + rmSecondary, err := oc.ks.GetResourceMutex(ctx, secondaryAddress) + if err != nil { + return err + } + err = rmSecondary.TryLock(keystore.TXMv2) + if err != nil { + return err + } + oc.lggr.Debugf("Key %s has been locked for TXMv2", secondaryAddress.String()) + return nil +} + +func (oc *dualContractTransmitter) Start(ctx context.Context) error { + return oc.lockTransmitters(ctx) +} +func (oc *dualContractTransmitter) Close() error { + return oc.unlockTransmitters(context.Background()) +} // Has no state/lifecycle so it's always healthy and ready func (oc *dualContractTransmitter) Ready() error { return nil } diff --git a/core/services/relay/evm/dual_contract_transmitter_test.go b/core/services/relay/evm/dual_contract_transmitter_test.go index a5110398159..e2376e46c76 100644 --- a/core/services/relay/evm/dual_contract_transmitter_test.go +++ b/core/services/relay/evm/dual_contract_transmitter_test.go @@ -19,6 +19,7 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/chains/evm/txmgr" "github.com/smartcontractkit/chainlink/v2/core/internal/testutils" "github.com/smartcontractkit/chainlink/v2/core/logger" + "github.com/smartcontractkit/chainlink/v2/core/services/keystore/mocks" "github.com/smartcontractkit/libocr/gethwrappers2/ocr2aggregator" "github.com/smartcontractkit/libocr/offchainreporting2plus/chains/evmutil" @@ -32,6 +33,10 @@ type mockDualTransmitter struct { lastSecondaryPayload []byte } +func (m *mockDualTransmitter) SecondaryFromAddress(ctx context.Context) (gethcommon.Address, error) { + return gethcommon.Address{}, nil +} + func (*mockDualTransmitter) FromAddress(ctx context.Context) gethcommon.Address { return sampleAddressPrimary } @@ -49,6 +54,7 @@ func (m *mockDualTransmitter) CreateSecondaryEthTransaction(ctx context.Context, func TestDualContractTransmitter(t *testing.T) { t.Parallel() + keyStore := mocks.NewEth(t) lggr := logger.TestLogger(t) c := evmclimocks.NewClient(t) lp := lpmocks.NewLogPoller(t) @@ -64,8 +70,7 @@ func TestDualContractTransmitter(t *testing.T) { reportToEvmTxMeta := func(b []byte) (*txmgr.TxMeta, error) { return &txmgr.TxMeta{}, nil } - ot, err := NewOCRDualContractTransmitter(ctx, gethcommon.Address{}, c, contractABI, &mockDualTransmitter{}, lp, lggr, - WithReportToEthMetadata(reportToEvmTxMeta)) + ot, err := NewOCRDualContractTransmitter(ctx, gethcommon.Address{}, c, contractABI, &mockDualTransmitter{}, lp, lggr, keyStore, WithReportToEthMetadata(reportToEvmTxMeta)) require.NoError(t, err) digest, epoch, err := ot.LatestConfigDigestAndEpoch(testutils.Context(t)) require.NoError(t, err) @@ -145,6 +150,7 @@ func Test_dualContractTransmitter_Transmit_SignaturesAreTransmitted(t *testing.T } func createDualContractTransmitter(ctx context.Context, t *testing.T, transmitter Transmitter, ops ...OCRTransmitterOption) *dualContractTransmitter { + keyStore := mocks.NewEth(t) contractABI, err := abi.JSON(strings.NewReader(ocr2aggregator.OCR2AggregatorMetaData.ABI)) require.NoError(t, err) lp := lpmocks.NewLogPoller(t) @@ -157,6 +163,7 @@ func createDualContractTransmitter(ctx context.Context, t *testing.T, transmitte transmitter, lp, logger.TestLogger(t), + keyStore, ops..., ) require.NoError(t, err) diff --git a/core/services/relay/evm/evm.go b/core/services/relay/evm/evm.go index 609e3751bee..65777589580 100644 --- a/core/services/relay/evm/evm.go +++ b/core/services/relay/evm/evm.go @@ -953,6 +953,7 @@ func newOnChainContractTransmitter(ctx context.Context, lggr logger.Logger, rarg transmitter, configWatcher.chain.LogPoller(), lggr, + ethKeystore, ocrTransmitterOpts..., ) } @@ -972,6 +973,7 @@ func newOnChainDualContractTransmitter(ctx context.Context, lggr logger.Logger, transmitter, configWatcher.chain.LogPoller(), lggr, + ethKeystore, ocrTransmitterOpts..., ) } @@ -1039,6 +1041,7 @@ func generateTransmitterFrom(ctx context.Context, rargs commontypes.RelayArgs, e switch commontypes.OCR2PluginType(rargs.ProviderType) { case commontypes.Median: transmitter, err = ocrcommon.NewOCR2FeedsTransmitter( + ctx, configWatcher.chain.TxManager(), fromAddresses, common.HexToAddress(rargs.ContractID),