diff --git a/common/client/mock_hashable_test.go b/common/client/mock_hashable_test.go
new file mode 100644
index 00000000000..d9f1670c073
--- /dev/null
+++ b/common/client/mock_hashable_test.go
@@ -0,0 +1,18 @@
+package client
+
+import "cmp"
+
+// Hashable is a simple implementation of the types.Hashable interface, used as a concrete type in tests.
+type Hashable string
+
+func (h Hashable) Cmp(c Hashable) int {
+	return cmp.Compare(h, c)
+}
+
+func (h Hashable) String() string {
+	return string(h)
+}
+
+func (h Hashable) Bytes() []byte {
+	return []byte(h)
+}
diff --git a/common/client/mock_head_test.go b/common/client/mock_head_test.go
new file mode 100644
index 00000000000..b9cf0a5866f
--- /dev/null
+++ b/common/client/mock_head_test.go
@@ -0,0 +1,57 @@
+// Code generated by mockery v2.35.4. DO NOT EDIT.
+
+package client
+
+import (
+	utils "github.com/smartcontractkit/chainlink/v2/core/utils"
+	mock "github.com/stretchr/testify/mock"
+)
+
+// mockHead is an autogenerated mock type for the Head type
+type mockHead struct {
+	mock.Mock
+}
+
+// BlockDifficulty provides a mock function with given fields:
+func (_m *mockHead) BlockDifficulty() *utils.Big {
+	ret := _m.Called()
+
+	var r0 *utils.Big
+	if rf, ok := ret.Get(0).(func() *utils.Big); ok {
+		r0 = rf()
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(*utils.Big)
+		}
+	}
+
+	return r0
+}
+
+// BlockNumber provides a mock function with given fields:
+func (_m *mockHead) BlockNumber() int64 {
+	ret := _m.Called()
+
+	var r0 int64
+	if rf, ok := ret.Get(0).(func() int64); ok {
+		r0 = rf()
+	} else {
+		r0 = ret.Get(0).(int64)
+	}
+
+	return r0
+}
+
+// newMockHead creates a new instance of mockHead. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
+// The first argument is typically a *testing.T value.
+func newMockHead(t interface {
+	mock.TestingT
+	Cleanup(func())
+}) *mockHead {
+	mock := &mockHead{}
+	mock.Mock.Test(t)
+
+	t.Cleanup(func() { mock.AssertExpectations(t) })
+
+	return mock
+}
diff --git a/common/client/mock_node_client_test.go b/common/client/mock_node_client_test.go
new file mode 100644
index 00000000000..7c8eb69171f
--- /dev/null
+++ b/common/client/mock_node_client_test.go
@@ -0,0 +1,168 @@
+// Code generated by mockery v2.35.4. DO NOT EDIT.
+ +package client + +import ( + context "context" + + types "github.com/smartcontractkit/chainlink/v2/common/types" + mock "github.com/stretchr/testify/mock" +) + +// mockNodeClient is an autogenerated mock type for the NodeClient type +type mockNodeClient[CHAIN_ID types.ID, HEAD Head] struct { + mock.Mock +} + +// ChainID provides a mock function with given fields: ctx +func (_m *mockNodeClient[CHAIN_ID, HEAD]) ChainID(ctx context.Context) (CHAIN_ID, error) { + ret := _m.Called(ctx) + + var r0 CHAIN_ID + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (CHAIN_ID, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) CHAIN_ID); ok { + r0 = rf(ctx) + } else { + r0 = ret.Get(0).(CHAIN_ID) + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ClientVersion provides a mock function with given fields: _a0 +func (_m *mockNodeClient[CHAIN_ID, HEAD]) ClientVersion(_a0 context.Context) (string, error) { + ret := _m.Called(_a0) + + var r0 string + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (string, error)); ok { + return rf(_a0) + } + if rf, ok := ret.Get(0).(func(context.Context) string); ok { + r0 = rf(_a0) + } else { + r0 = ret.Get(0).(string) + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(_a0) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Close provides a mock function with given fields: +func (_m *mockNodeClient[CHAIN_ID, HEAD]) Close() { + _m.Called() +} + +// Dial provides a mock function with given fields: ctx +func (_m *mockNodeClient[CHAIN_ID, HEAD]) Dial(ctx context.Context) error { + ret := _m.Called(ctx) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(ctx) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// DialHTTP provides a mock function with given fields: +func (_m *mockNodeClient[CHAIN_ID, HEAD]) DialHTTP() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// DisconnectAll provides a mock function with given fields: +func (_m *mockNodeClient[CHAIN_ID, HEAD]) DisconnectAll() { + _m.Called() +} + +// SetAliveLoopSub provides a mock function with given fields: _a0 +func (_m *mockNodeClient[CHAIN_ID, HEAD]) SetAliveLoopSub(_a0 types.Subscription) { + _m.Called(_a0) +} + +// Subscribe provides a mock function with given fields: ctx, channel, args +func (_m *mockNodeClient[CHAIN_ID, HEAD]) Subscribe(ctx context.Context, channel chan<- HEAD, args ...interface{}) (types.Subscription, error) { + var _ca []interface{} + _ca = append(_ca, ctx, channel) + _ca = append(_ca, args...) + ret := _m.Called(_ca...) + + var r0 types.Subscription + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, chan<- HEAD, ...interface{}) (types.Subscription, error)); ok { + return rf(ctx, channel, args...) + } + if rf, ok := ret.Get(0).(func(context.Context, chan<- HEAD, ...interface{}) types.Subscription); ok { + r0 = rf(ctx, channel, args...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(types.Subscription) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, chan<- HEAD, ...interface{}) error); ok { + r1 = rf(ctx, channel, args...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// SubscribersCount provides a mock function with given fields: +func (_m *mockNodeClient[CHAIN_ID, HEAD]) SubscribersCount() int32 { + ret := _m.Called() + + var r0 int32 + if rf, ok := ret.Get(0).(func() int32); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int32) + } + + return r0 +} + +// UnsubscribeAllExceptAliveLoop provides a mock function with given fields: +func (_m *mockNodeClient[CHAIN_ID, HEAD]) UnsubscribeAllExceptAliveLoop() { + _m.Called() +} + +// newMockNodeClient creates a new instance of mockNodeClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func newMockNodeClient[CHAIN_ID types.ID, HEAD Head](t interface { + mock.TestingT + Cleanup(func()) +}) *mockNodeClient[CHAIN_ID, HEAD] { + mock := &mockNodeClient[CHAIN_ID, HEAD]{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/common/client/mock_node_selector_test.go b/common/client/mock_node_selector_test.go new file mode 100644 index 00000000000..e7b8d9ecb8d --- /dev/null +++ b/common/client/mock_node_selector_test.go @@ -0,0 +1,57 @@ +// Code generated by mockery v2.35.4. DO NOT EDIT. + +package client + +import ( + types "github.com/smartcontractkit/chainlink/v2/common/types" + mock "github.com/stretchr/testify/mock" +) + +// mockNodeSelector is an autogenerated mock type for the NodeSelector type +type mockNodeSelector[CHAIN_ID types.ID, HEAD Head, RPC NodeClient[CHAIN_ID, HEAD]] struct { + mock.Mock +} + +// Name provides a mock function with given fields: +func (_m *mockNodeSelector[CHAIN_ID, HEAD, RPC]) Name() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// Select provides a mock function with given fields: +func (_m *mockNodeSelector[CHAIN_ID, HEAD, RPC]) Select() Node[CHAIN_ID, HEAD, RPC] { + ret := _m.Called() + + var r0 Node[CHAIN_ID, HEAD, RPC] + if rf, ok := ret.Get(0).(func() Node[CHAIN_ID, HEAD, RPC]); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(Node[CHAIN_ID, HEAD, RPC]) + } + } + + return r0 +} + +// newMockNodeSelector creates a new instance of mockNodeSelector. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func newMockNodeSelector[CHAIN_ID types.ID, HEAD Head, RPC NodeClient[CHAIN_ID, HEAD]](t interface { + mock.TestingT + Cleanup(func()) +}) *mockNodeSelector[CHAIN_ID, HEAD, RPC] { + mock := &mockNodeSelector[CHAIN_ID, HEAD, RPC]{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/common/client/mock_node_test.go b/common/client/mock_node_test.go new file mode 100644 index 00000000000..bd704cd2c6f --- /dev/null +++ b/common/client/mock_node_test.go @@ -0,0 +1,195 @@ +// Code generated by mockery v2.35.4. DO NOT EDIT. 
+ +package client + +import ( + context "context" + + types "github.com/smartcontractkit/chainlink/v2/common/types" + mock "github.com/stretchr/testify/mock" + + utils "github.com/smartcontractkit/chainlink/v2/core/utils" +) + +// mockNode is an autogenerated mock type for the Node type +type mockNode[CHAIN_ID types.ID, HEAD Head, RPC NodeClient[CHAIN_ID, HEAD]] struct { + mock.Mock +} + +// Close provides a mock function with given fields: +func (_m *mockNode[CHAIN_ID, HEAD, RPC]) Close() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// ConfiguredChainID provides a mock function with given fields: +func (_m *mockNode[CHAIN_ID, HEAD, RPC]) ConfiguredChainID() CHAIN_ID { + ret := _m.Called() + + var r0 CHAIN_ID + if rf, ok := ret.Get(0).(func() CHAIN_ID); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(CHAIN_ID) + } + + return r0 +} + +// Name provides a mock function with given fields: +func (_m *mockNode[CHAIN_ID, HEAD, RPC]) Name() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// Order provides a mock function with given fields: +func (_m *mockNode[CHAIN_ID, HEAD, RPC]) Order() int32 { + ret := _m.Called() + + var r0 int32 + if rf, ok := ret.Get(0).(func() int32); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int32) + } + + return r0 +} + +// RPC provides a mock function with given fields: +func (_m *mockNode[CHAIN_ID, HEAD, RPC]) RPC() RPC { + ret := _m.Called() + + var r0 RPC + if rf, ok := ret.Get(0).(func() RPC); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(RPC) + } + + return r0 +} + +// Start provides a mock function with given fields: _a0 +func (_m *mockNode[CHAIN_ID, HEAD, RPC]) Start(_a0 context.Context) error { + ret := _m.Called(_a0) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// State provides a mock function with given fields: +func (_m *mockNode[CHAIN_ID, HEAD, RPC]) State() nodeState { + ret := _m.Called() + + var r0 nodeState + if rf, ok := ret.Get(0).(func() nodeState); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(nodeState) + } + + return r0 +} + +// StateAndLatest provides a mock function with given fields: +func (_m *mockNode[CHAIN_ID, HEAD, RPC]) StateAndLatest() (nodeState, int64, *utils.Big) { + ret := _m.Called() + + var r0 nodeState + var r1 int64 + var r2 *utils.Big + if rf, ok := ret.Get(0).(func() (nodeState, int64, *utils.Big)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() nodeState); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(nodeState) + } + + if rf, ok := ret.Get(1).(func() int64); ok { + r1 = rf() + } else { + r1 = ret.Get(1).(int64) + } + + if rf, ok := ret.Get(2).(func() *utils.Big); ok { + r2 = rf() + } else { + if ret.Get(2) != nil { + r2 = ret.Get(2).(*utils.Big) + } + } + + return r0, r1, r2 +} + +// String provides a mock function with given fields: +func (_m *mockNode[CHAIN_ID, HEAD, RPC]) String() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// SubscribersCount provides a mock function with given fields: +func (_m *mockNode[CHAIN_ID, HEAD, RPC]) SubscribersCount() int32 { + ret := _m.Called() + + var r0 int32 + if rf, ok := ret.Get(0).(func() int32); ok { + r0 = rf() 
+ } else { + r0 = ret.Get(0).(int32) + } + + return r0 +} + +// UnsubscribeAllExceptAliveLoop provides a mock function with given fields: +func (_m *mockNode[CHAIN_ID, HEAD, RPC]) UnsubscribeAllExceptAliveLoop() { + _m.Called() +} + +// newMockNode creates a new instance of mockNode. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func newMockNode[CHAIN_ID types.ID, HEAD Head, RPC NodeClient[CHAIN_ID, HEAD]](t interface { + mock.TestingT + Cleanup(func()) +}) *mockNode[CHAIN_ID, HEAD, RPC] { + mock := &mockNode[CHAIN_ID, HEAD, RPC]{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/common/client/mock_rpc_test.go b/common/client/mock_rpc_test.go new file mode 100644 index 00000000000..c378b9384e4 --- /dev/null +++ b/common/client/mock_rpc_test.go @@ -0,0 +1,608 @@ +// Code generated by mockery v2.35.4. DO NOT EDIT. + +package client + +import ( + big "math/big" + + assets "github.com/smartcontractkit/chainlink/v2/core/assets" + + context "context" + + feetypes "github.com/smartcontractkit/chainlink/v2/common/fee/types" + + mock "github.com/stretchr/testify/mock" + + types "github.com/smartcontractkit/chainlink/v2/common/types" +) + +// mockRPC is an autogenerated mock type for the RPC type +type mockRPC[CHAIN_ID types.ID, SEQ types.Sequence, ADDR types.Hashable, BLOCK_HASH types.Hashable, TX interface{}, TX_HASH types.Hashable, EVENT interface{}, EVENT_OPS interface{}, TX_RECEIPT types.Receipt[TX_HASH, BLOCK_HASH], FEE feetypes.Fee, HEAD types.Head[BLOCK_HASH]] struct { + mock.Mock +} + +// BalanceAt provides a mock function with given fields: ctx, accountAddress, blockNumber +func (_m *mockRPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD]) BalanceAt(ctx context.Context, accountAddress ADDR, blockNumber *big.Int) (*big.Int, error) { + ret := _m.Called(ctx, accountAddress, blockNumber) + + var r0 *big.Int + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, ADDR, *big.Int) (*big.Int, error)); ok { + return rf(ctx, accountAddress, blockNumber) + } + if rf, ok := ret.Get(0).(func(context.Context, ADDR, *big.Int) *big.Int); ok { + r0 = rf(ctx, accountAddress, blockNumber) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*big.Int) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, ADDR, *big.Int) error); ok { + r1 = rf(ctx, accountAddress, blockNumber) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// BatchCallContext provides a mock function with given fields: ctx, b +func (_m *mockRPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD]) BatchCallContext(ctx context.Context, b []interface{}) error { + ret := _m.Called(ctx, b) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, []interface{}) error); ok { + r0 = rf(ctx, b) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// BlockByHash provides a mock function with given fields: ctx, hash +func (_m *mockRPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD]) BlockByHash(ctx context.Context, hash BLOCK_HASH) (HEAD, error) { + ret := _m.Called(ctx, hash) + + var r0 HEAD + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, BLOCK_HASH) (HEAD, error)); ok { + return rf(ctx, hash) + } + if rf, ok := ret.Get(0).(func(context.Context, BLOCK_HASH) HEAD); ok { + r0 = rf(ctx, hash) + } else { + r0 = 
ret.Get(0).(HEAD) + } + + if rf, ok := ret.Get(1).(func(context.Context, BLOCK_HASH) error); ok { + r1 = rf(ctx, hash) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// BlockByNumber provides a mock function with given fields: ctx, number +func (_m *mockRPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD]) BlockByNumber(ctx context.Context, number *big.Int) (HEAD, error) { + ret := _m.Called(ctx, number) + + var r0 HEAD + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *big.Int) (HEAD, error)); ok { + return rf(ctx, number) + } + if rf, ok := ret.Get(0).(func(context.Context, *big.Int) HEAD); ok { + r0 = rf(ctx, number) + } else { + r0 = ret.Get(0).(HEAD) + } + + if rf, ok := ret.Get(1).(func(context.Context, *big.Int) error); ok { + r1 = rf(ctx, number) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// CallContext provides a mock function with given fields: ctx, result, method, args +func (_m *mockRPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD]) CallContext(ctx context.Context, result interface{}, method string, args ...interface{}) error { + var _ca []interface{} + _ca = append(_ca, ctx, result, method) + _ca = append(_ca, args...) + ret := _m.Called(_ca...) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, interface{}, string, ...interface{}) error); ok { + r0 = rf(ctx, result, method, args...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// CallContract provides a mock function with given fields: ctx, msg, blockNumber +func (_m *mockRPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD]) CallContract(ctx context.Context, msg interface{}, blockNumber *big.Int) ([]byte, error) { + ret := _m.Called(ctx, msg, blockNumber) + + var r0 []byte + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, interface{}, *big.Int) ([]byte, error)); ok { + return rf(ctx, msg, blockNumber) + } + if rf, ok := ret.Get(0).(func(context.Context, interface{}, *big.Int) []byte); ok { + r0 = rf(ctx, msg, blockNumber) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, interface{}, *big.Int) error); ok { + r1 = rf(ctx, msg, blockNumber) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ChainID provides a mock function with given fields: ctx +func (_m *mockRPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD]) ChainID(ctx context.Context) (CHAIN_ID, error) { + ret := _m.Called(ctx) + + var r0 CHAIN_ID + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (CHAIN_ID, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) CHAIN_ID); ok { + r0 = rf(ctx) + } else { + r0 = ret.Get(0).(CHAIN_ID) + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ClientVersion provides a mock function with given fields: _a0 +func (_m *mockRPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD]) ClientVersion(_a0 context.Context) (string, error) { + ret := _m.Called(_a0) + + var r0 string + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (string, error)); ok { + return rf(_a0) + } + if rf, ok := ret.Get(0).(func(context.Context) string); ok { + r0 = rf(_a0) + } else { + r0 = ret.Get(0).(string) + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok 
{ + r1 = rf(_a0) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Close provides a mock function with given fields: +func (_m *mockRPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD]) Close() { + _m.Called() +} + +// CodeAt provides a mock function with given fields: ctx, account, blockNumber +func (_m *mockRPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD]) CodeAt(ctx context.Context, account ADDR, blockNumber *big.Int) ([]byte, error) { + ret := _m.Called(ctx, account, blockNumber) + + var r0 []byte + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, ADDR, *big.Int) ([]byte, error)); ok { + return rf(ctx, account, blockNumber) + } + if rf, ok := ret.Get(0).(func(context.Context, ADDR, *big.Int) []byte); ok { + r0 = rf(ctx, account, blockNumber) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, ADDR, *big.Int) error); ok { + r1 = rf(ctx, account, blockNumber) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Dial provides a mock function with given fields: ctx +func (_m *mockRPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD]) Dial(ctx context.Context) error { + ret := _m.Called(ctx) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(ctx) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// DialHTTP provides a mock function with given fields: +func (_m *mockRPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD]) DialHTTP() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// DisconnectAll provides a mock function with given fields: +func (_m *mockRPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD]) DisconnectAll() { + _m.Called() +} + +// EstimateGas provides a mock function with given fields: ctx, call +func (_m *mockRPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD]) EstimateGas(ctx context.Context, call interface{}) (uint64, error) { + ret := _m.Called(ctx, call) + + var r0 uint64 + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, interface{}) (uint64, error)); ok { + return rf(ctx, call) + } + if rf, ok := ret.Get(0).(func(context.Context, interface{}) uint64); ok { + r0 = rf(ctx, call) + } else { + r0 = ret.Get(0).(uint64) + } + + if rf, ok := ret.Get(1).(func(context.Context, interface{}) error); ok { + r1 = rf(ctx, call) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// FilterEvents provides a mock function with given fields: ctx, query +func (_m *mockRPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD]) FilterEvents(ctx context.Context, query EVENT_OPS) ([]EVENT, error) { + ret := _m.Called(ctx, query) + + var r0 []EVENT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, EVENT_OPS) ([]EVENT, error)); ok { + return rf(ctx, query) + } + if rf, ok := ret.Get(0).(func(context.Context, EVENT_OPS) []EVENT); ok { + r0 = rf(ctx, query) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]EVENT) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, EVENT_OPS) error); ok { + r1 = rf(ctx, query) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// LINKBalance provides a mock function with given fields: ctx, 
accountAddress, linkAddress +func (_m *mockRPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD]) LINKBalance(ctx context.Context, accountAddress ADDR, linkAddress ADDR) (*assets.Link, error) { + ret := _m.Called(ctx, accountAddress, linkAddress) + + var r0 *assets.Link + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, ADDR, ADDR) (*assets.Link, error)); ok { + return rf(ctx, accountAddress, linkAddress) + } + if rf, ok := ret.Get(0).(func(context.Context, ADDR, ADDR) *assets.Link); ok { + r0 = rf(ctx, accountAddress, linkAddress) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*assets.Link) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, ADDR, ADDR) error); ok { + r1 = rf(ctx, accountAddress, linkAddress) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// LatestBlockHeight provides a mock function with given fields: _a0 +func (_m *mockRPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD]) LatestBlockHeight(_a0 context.Context) (*big.Int, error) { + ret := _m.Called(_a0) + + var r0 *big.Int + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (*big.Int, error)); ok { + return rf(_a0) + } + if rf, ok := ret.Get(0).(func(context.Context) *big.Int); ok { + r0 = rf(_a0) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*big.Int) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(_a0) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// PendingSequenceAt provides a mock function with given fields: ctx, addr +func (_m *mockRPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD]) PendingSequenceAt(ctx context.Context, addr ADDR) (SEQ, error) { + ret := _m.Called(ctx, addr) + + var r0 SEQ + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, ADDR) (SEQ, error)); ok { + return rf(ctx, addr) + } + if rf, ok := ret.Get(0).(func(context.Context, ADDR) SEQ); ok { + r0 = rf(ctx, addr) + } else { + r0 = ret.Get(0).(SEQ) + } + + if rf, ok := ret.Get(1).(func(context.Context, ADDR) error); ok { + r1 = rf(ctx, addr) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// SendEmptyTransaction provides a mock function with given fields: ctx, newTxAttempt, seq, gasLimit, fee, fromAddress +func (_m *mockRPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD]) SendEmptyTransaction(ctx context.Context, newTxAttempt func(SEQ, uint32, FEE, ADDR) (interface{}, error), seq SEQ, gasLimit uint32, fee FEE, fromAddress ADDR) (string, error) { + ret := _m.Called(ctx, newTxAttempt, seq, gasLimit, fee, fromAddress) + + var r0 string + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, func(SEQ, uint32, FEE, ADDR) (interface{}, error), SEQ, uint32, FEE, ADDR) (string, error)); ok { + return rf(ctx, newTxAttempt, seq, gasLimit, fee, fromAddress) + } + if rf, ok := ret.Get(0).(func(context.Context, func(SEQ, uint32, FEE, ADDR) (interface{}, error), SEQ, uint32, FEE, ADDR) string); ok { + r0 = rf(ctx, newTxAttempt, seq, gasLimit, fee, fromAddress) + } else { + r0 = ret.Get(0).(string) + } + + if rf, ok := ret.Get(1).(func(context.Context, func(SEQ, uint32, FEE, ADDR) (interface{}, error), SEQ, uint32, FEE, ADDR) error); ok { + r1 = rf(ctx, newTxAttempt, seq, gasLimit, fee, fromAddress) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// SendTransaction provides a mock function with given fields: ctx, tx +func (_m *mockRPC[CHAIN_ID, SEQ, ADDR, 
BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD]) SendTransaction(ctx context.Context, tx TX) error { + ret := _m.Called(ctx, tx) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, TX) error); ok { + r0 = rf(ctx, tx) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SequenceAt provides a mock function with given fields: ctx, accountAddress, blockNumber +func (_m *mockRPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD]) SequenceAt(ctx context.Context, accountAddress ADDR, blockNumber *big.Int) (SEQ, error) { + ret := _m.Called(ctx, accountAddress, blockNumber) + + var r0 SEQ + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, ADDR, *big.Int) (SEQ, error)); ok { + return rf(ctx, accountAddress, blockNumber) + } + if rf, ok := ret.Get(0).(func(context.Context, ADDR, *big.Int) SEQ); ok { + r0 = rf(ctx, accountAddress, blockNumber) + } else { + r0 = ret.Get(0).(SEQ) + } + + if rf, ok := ret.Get(1).(func(context.Context, ADDR, *big.Int) error); ok { + r1 = rf(ctx, accountAddress, blockNumber) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// SetAliveLoopSub provides a mock function with given fields: _a0 +func (_m *mockRPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD]) SetAliveLoopSub(_a0 types.Subscription) { + _m.Called(_a0) +} + +// SimulateTransaction provides a mock function with given fields: ctx, tx +func (_m *mockRPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD]) SimulateTransaction(ctx context.Context, tx TX) error { + ret := _m.Called(ctx, tx) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, TX) error); ok { + r0 = rf(ctx, tx) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Subscribe provides a mock function with given fields: ctx, channel, args +func (_m *mockRPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD]) Subscribe(ctx context.Context, channel chan<- HEAD, args ...interface{}) (types.Subscription, error) { + var _ca []interface{} + _ca = append(_ca, ctx, channel) + _ca = append(_ca, args...) + ret := _m.Called(_ca...) + + var r0 types.Subscription + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, chan<- HEAD, ...interface{}) (types.Subscription, error)); ok { + return rf(ctx, channel, args...) + } + if rf, ok := ret.Get(0).(func(context.Context, chan<- HEAD, ...interface{}) types.Subscription); ok { + r0 = rf(ctx, channel, args...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(types.Subscription) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, chan<- HEAD, ...interface{}) error); ok { + r1 = rf(ctx, channel, args...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// SubscribersCount provides a mock function with given fields: +func (_m *mockRPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD]) SubscribersCount() int32 { + ret := _m.Called() + + var r0 int32 + if rf, ok := ret.Get(0).(func() int32); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int32) + } + + return r0 +} + +// TokenBalance provides a mock function with given fields: ctx, accountAddress, tokenAddress +func (_m *mockRPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD]) TokenBalance(ctx context.Context, accountAddress ADDR, tokenAddress ADDR) (*big.Int, error) { + ret := _m.Called(ctx, accountAddress, tokenAddress) + + var r0 *big.Int + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, ADDR, ADDR) (*big.Int, error)); ok { + return rf(ctx, accountAddress, tokenAddress) + } + if rf, ok := ret.Get(0).(func(context.Context, ADDR, ADDR) *big.Int); ok { + r0 = rf(ctx, accountAddress, tokenAddress) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*big.Int) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, ADDR, ADDR) error); ok { + r1 = rf(ctx, accountAddress, tokenAddress) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// TransactionByHash provides a mock function with given fields: ctx, txHash +func (_m *mockRPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD]) TransactionByHash(ctx context.Context, txHash TX_HASH) (TX, error) { + ret := _m.Called(ctx, txHash) + + var r0 TX + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, TX_HASH) (TX, error)); ok { + return rf(ctx, txHash) + } + if rf, ok := ret.Get(0).(func(context.Context, TX_HASH) TX); ok { + r0 = rf(ctx, txHash) + } else { + r0 = ret.Get(0).(TX) + } + + if rf, ok := ret.Get(1).(func(context.Context, TX_HASH) error); ok { + r1 = rf(ctx, txHash) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// TransactionReceipt provides a mock function with given fields: ctx, txHash +func (_m *mockRPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD]) TransactionReceipt(ctx context.Context, txHash TX_HASH) (TX_RECEIPT, error) { + ret := _m.Called(ctx, txHash) + + var r0 TX_RECEIPT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, TX_HASH) (TX_RECEIPT, error)); ok { + return rf(ctx, txHash) + } + if rf, ok := ret.Get(0).(func(context.Context, TX_HASH) TX_RECEIPT); ok { + r0 = rf(ctx, txHash) + } else { + r0 = ret.Get(0).(TX_RECEIPT) + } + + if rf, ok := ret.Get(1).(func(context.Context, TX_HASH) error); ok { + r1 = rf(ctx, txHash) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// UnsubscribeAllExceptAliveLoop provides a mock function with given fields: +func (_m *mockRPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD]) UnsubscribeAllExceptAliveLoop() { + _m.Called() +} + +// newMockRPC creates a new instance of mockRPC. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. 
+func newMockRPC[CHAIN_ID types.ID, SEQ types.Sequence, ADDR types.Hashable, BLOCK_HASH types.Hashable, TX interface{}, TX_HASH types.Hashable, EVENT interface{}, EVENT_OPS interface{}, TX_RECEIPT types.Receipt[TX_HASH, BLOCK_HASH], FEE feetypes.Fee, HEAD types.Head[BLOCK_HASH]](t interface { + mock.TestingT + Cleanup(func()) +}) *mockRPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD] { + mock := &mockRPC[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD]{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/common/client/mock_send_only_client_test.go b/common/client/mock_send_only_client_test.go new file mode 100644 index 00000000000..481b2602ea3 --- /dev/null +++ b/common/client/mock_send_only_client_test.go @@ -0,0 +1,72 @@ +// Code generated by mockery v2.35.4. DO NOT EDIT. + +package client + +import ( + context "context" + + types "github.com/smartcontractkit/chainlink/v2/common/types" + mock "github.com/stretchr/testify/mock" +) + +// mockSendOnlyClient is an autogenerated mock type for the sendOnlyClient type +type mockSendOnlyClient[CHAIN_ID types.ID] struct { + mock.Mock +} + +// ChainID provides a mock function with given fields: _a0 +func (_m *mockSendOnlyClient[CHAIN_ID]) ChainID(_a0 context.Context) (CHAIN_ID, error) { + ret := _m.Called(_a0) + + var r0 CHAIN_ID + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (CHAIN_ID, error)); ok { + return rf(_a0) + } + if rf, ok := ret.Get(0).(func(context.Context) CHAIN_ID); ok { + r0 = rf(_a0) + } else { + r0 = ret.Get(0).(CHAIN_ID) + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(_a0) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Close provides a mock function with given fields: +func (_m *mockSendOnlyClient[CHAIN_ID]) Close() { + _m.Called() +} + +// DialHTTP provides a mock function with given fields: +func (_m *mockSendOnlyClient[CHAIN_ID]) DialHTTP() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// newMockSendOnlyClient creates a new instance of mockSendOnlyClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func newMockSendOnlyClient[CHAIN_ID types.ID](t interface { + mock.TestingT + Cleanup(func()) +}) *mockSendOnlyClient[CHAIN_ID] { + mock := &mockSendOnlyClient[CHAIN_ID]{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/common/client/mock_send_only_node_test.go b/common/client/mock_send_only_node_test.go new file mode 100644 index 00000000000..524d7d8a6c5 --- /dev/null +++ b/common/client/mock_send_only_node_test.go @@ -0,0 +1,127 @@ +// Code generated by mockery v2.35.4. DO NOT EDIT. 
+
+package client
+
+import (
+	context "context"
+
+	types "github.com/smartcontractkit/chainlink/v2/common/types"
+	mock "github.com/stretchr/testify/mock"
+)
+
+// mockSendOnlyNode is an autogenerated mock type for the SendOnlyNode type
+type mockSendOnlyNode[CHAIN_ID types.ID, RPC sendOnlyClient[CHAIN_ID]] struct {
+	mock.Mock
+}
+
+// Close provides a mock function with given fields:
+func (_m *mockSendOnlyNode[CHAIN_ID, RPC]) Close() error {
+	ret := _m.Called()
+
+	var r0 error
+	if rf, ok := ret.Get(0).(func() error); ok {
+		r0 = rf()
+	} else {
+		r0 = ret.Error(0)
+	}
+
+	return r0
+}
+
+// ConfiguredChainID provides a mock function with given fields:
+func (_m *mockSendOnlyNode[CHAIN_ID, RPC]) ConfiguredChainID() CHAIN_ID {
+	ret := _m.Called()
+
+	var r0 CHAIN_ID
+	if rf, ok := ret.Get(0).(func() CHAIN_ID); ok {
+		r0 = rf()
+	} else {
+		r0 = ret.Get(0).(CHAIN_ID)
+	}
+
+	return r0
+}
+
+// Name provides a mock function with given fields:
+func (_m *mockSendOnlyNode[CHAIN_ID, RPC]) Name() string {
+	ret := _m.Called()
+
+	var r0 string
+	if rf, ok := ret.Get(0).(func() string); ok {
+		r0 = rf()
+	} else {
+		r0 = ret.Get(0).(string)
+	}
+
+	return r0
+}
+
+// RPC provides a mock function with given fields:
+func (_m *mockSendOnlyNode[CHAIN_ID, RPC]) RPC() RPC {
+	ret := _m.Called()
+
+	var r0 RPC
+	if rf, ok := ret.Get(0).(func() RPC); ok {
+		r0 = rf()
+	} else {
+		r0 = ret.Get(0).(RPC)
+	}
+
+	return r0
+}
+
+// Start provides a mock function with given fields: _a0
+func (_m *mockSendOnlyNode[CHAIN_ID, RPC]) Start(_a0 context.Context) error {
+	ret := _m.Called(_a0)
+
+	var r0 error
+	if rf, ok := ret.Get(0).(func(context.Context) error); ok {
+		r0 = rf(_a0)
+	} else {
+		r0 = ret.Error(0)
+	}
+
+	return r0
+}
+
+// State provides a mock function with given fields:
+func (_m *mockSendOnlyNode[CHAIN_ID, RPC]) State() nodeState {
+	ret := _m.Called()
+
+	var r0 nodeState
+	if rf, ok := ret.Get(0).(func() nodeState); ok {
+		r0 = rf()
+	} else {
+		r0 = ret.Get(0).(nodeState)
+	}
+
+	return r0
+}
+
+// String provides a mock function with given fields:
+func (_m *mockSendOnlyNode[CHAIN_ID, RPC]) String() string {
+	ret := _m.Called()
+
+	var r0 string
+	if rf, ok := ret.Get(0).(func() string); ok {
+		r0 = rf()
+	} else {
+		r0 = ret.Get(0).(string)
+	}
+
+	return r0
+}
+
+// newMockSendOnlyNode creates a new instance of mockSendOnlyNode. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
+// The first argument is typically a *testing.T value.
+func newMockSendOnlyNode[CHAIN_ID types.ID, RPC sendOnlyClient[CHAIN_ID]](t interface {
+	mock.TestingT
+	Cleanup(func())
+}) *mockSendOnlyNode[CHAIN_ID, RPC] {
+	mock := &mockSendOnlyNode[CHAIN_ID, RPC]{}
+	mock.Mock.Test(t)
+
+	t.Cleanup(func() { mock.AssertExpectations(t) })
+
+	return mock
+}
diff --git a/common/client/multi_node.go b/common/client/multi_node.go
index f54e3115d95..c268cfb23cd 100644
--- a/common/client/multi_node.go
+++ b/common/client/multi_node.go
@@ -30,25 +30,6 @@ var (
 	ErroringNodeError = fmt.Errorf("no live nodes available")
 )
 
-const (
-	NodeSelectionModeHighestHead     = "HighestHead"
-	NodeSelectionModeRoundRobin      = "RoundRobin"
-	NodeSelectionModeTotalDifficulty = "TotalDifficulty"
-	NodeSelectionModePriorityLevel   = "PriorityLevel"
-)
-
-type NodeSelector[
-	CHAIN_ID types.ID,
-	HEAD Head,
-	RPC NodeClient[CHAIN_ID, HEAD],
-] interface {
-	// Select returns a Node, or nil if none can be selected.
-	// Implementation must be thread-safe.
-	Select() Node[CHAIN_ID, HEAD, RPC]
-	// Name returns the strategy name, e.g. "HighestHead" or "RoundRobin"
-	Name() string
-}
-
 // MultiNode is a generalized multi node client interface that includes methods to interact with different chains.
 // It also handles multiple node RPC connections simultaneously.
 type MultiNode[
@@ -113,6 +94,7 @@
 	leaseDuration time.Duration
 	leaseTicker   *time.Ticker
 	chainFamily   string
+	reportInterval time.Duration
 
 	activeMu   sync.RWMutex
 	activeNode Node[CHAIN_ID, HEAD, RPC_CLIENT]
@@ -148,23 +130,13 @@ func NewMultiNode[
 	chainFamily string,
 	sendOnlyErrorParser func(err error) SendTxReturnCode,
 ) MultiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT] {
-	nodeSelector := func() NodeSelector[CHAIN_ID, HEAD, RPC_CLIENT] {
-		switch selectionMode {
-		case NodeSelectionModeHighestHead:
-			return NewHighestHeadNodeSelector[CHAIN_ID, HEAD, RPC_CLIENT](nodes)
-		case NodeSelectionModeRoundRobin:
-			return NewRoundRobinSelector[CHAIN_ID, HEAD, RPC_CLIENT](nodes)
-		case NodeSelectionModeTotalDifficulty:
-			return NewTotalDifficultyNodeSelector[CHAIN_ID, HEAD, RPC_CLIENT](nodes)
-		case NodeSelectionModePriorityLevel:
-			return NewPriorityLevelNodeSelector[CHAIN_ID, HEAD, RPC_CLIENT](nodes)
-		default:
-			panic(fmt.Sprintf("unsupported NodeSelectionMode: %s", selectionMode))
-		}
-	}()
+	nodeSelector := newNodeSelector(selectionMode, nodes)
 
 	lggr := logger.Named("MultiNode").With("chainID", chainID.String())
 
+	// Prometheus' default interval is 15s, set this to under 7.5s to avoid
+	// aliasing (see: https://en.wikipedia.org/wiki/Nyquist_frequency)
+	const reportInterval = 6500 * time.Millisecond
 	c := &multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OPS, TX_RECEIPT, FEE, HEAD, RPC_CLIENT]{
 		nodes:     nodes,
 		sendonlys: sendonlys,
@@ -178,6 +150,7 @@
 		leaseDuration:       leaseDuration,
 		chainFamily:         chainFamily,
 		sendOnlyErrorParser: sendOnlyErrorParser,
+		reportInterval:      reportInterval,
 	}
 
 	c.logger.Debugf("The MultiNode is configured to use NodeSelectionMode: %s", selectionMode)
@@ -341,10 +314,7 @@ func (c *multiNode[CHAIN_ID, SEQ, ADDR, BLOCK_HASH, TX, TX_HASH, EVENT, EVENT_OP
 
 	c.report()
 
-	// Prometheus' default interval is 15s, set this to under 7.5s to avoid
-	// aliasing (see: https://en.wikipedia.org/wiki/Nyquist_frequency)
-	reportInterval := 6500 * time.Millisecond
-	monitor := time.NewTicker(utils.WithJitter(reportInterval))
+	monitor := time.NewTicker(utils.WithJitter(c.reportInterval))
 	defer monitor.Stop()
 
 	for {
diff --git a/common/client/multi_node_test.go b/common/client/multi_node_test.go
new file mode 100644
index 00000000000..1fddbc3be3c
--- /dev/null
+++ b/common/client/multi_node_test.go
@@ -0,0 +1,635 @@
+package client
+
+import (
+	"fmt"
+	"math/rand"
+	"testing"
+	"time"
+
+	"github.com/pkg/errors"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/mock"
+	"github.com/stretchr/testify/require"
+	"go.uber.org/zap"
+
+	"github.com/smartcontractkit/chainlink-relay/pkg/utils/tests"
+
+	"github.com/smartcontractkit/chainlink/v2/common/types"
+	"github.com/smartcontractkit/chainlink/v2/core/config"
+	"github.com/smartcontractkit/chainlink/v2/core/logger"
+	"github.com/smartcontractkit/chainlink/v2/core/utils"
+)
+
+type multiNodeRPCClient RPC[types.ID, *utils.Big, Hashable, Hashable, any, Hashable, any, any,
+	types.Receipt[Hashable, Hashable], Hashable, types.Head[Hashable]]
+
+type testMultiNode struct {
+	*multiNode[types.ID, *utils.Big, Hashable,
Hashable, any, Hashable, any, any, + types.Receipt[Hashable, Hashable], Hashable, types.Head[Hashable], multiNodeRPCClient] +} + +type multiNodeOpts struct { + logger logger.Logger + selectionMode string + leaseDuration time.Duration + noNewHeadsThreshold time.Duration + nodes []Node[types.ID, types.Head[Hashable], multiNodeRPCClient] + sendonlys []SendOnlyNode[types.ID, multiNodeRPCClient] + chainID types.ID + chainType config.ChainType + chainFamily string + sendOnlyErrorParser func(err error) SendTxReturnCode +} + +func newTestMultiNode(t *testing.T, opts multiNodeOpts) testMultiNode { + if opts.logger == nil { + opts.logger = logger.TestLogger(t) + } + + result := NewMultiNode[types.ID, *utils.Big, Hashable, Hashable, any, Hashable, any, any, + types.Receipt[Hashable, Hashable], Hashable, types.Head[Hashable], multiNodeRPCClient](opts.logger, + opts.selectionMode, opts.leaseDuration, opts.noNewHeadsThreshold, opts.nodes, opts.sendonlys, + opts.chainID, opts.chainType, opts.chainFamily, opts.sendOnlyErrorParser) + return testMultiNode{ + result.(*multiNode[types.ID, *utils.Big, Hashable, Hashable, any, Hashable, any, any, + types.Receipt[Hashable, Hashable], Hashable, types.Head[Hashable], multiNodeRPCClient]), + } +} + +func newMultiNodeRPCClient(t *testing.T) *mockRPC[types.ID, *utils.Big, Hashable, Hashable, any, Hashable, any, any, + types.Receipt[Hashable, Hashable], Hashable, types.Head[Hashable]] { + return newMockRPC[types.ID, *utils.Big, Hashable, Hashable, any, Hashable, any, any, + types.Receipt[Hashable, Hashable], Hashable, types.Head[Hashable]](t) +} + +func newHealthyNode(t *testing.T, chainID types.ID) *mockNode[types.ID, types.Head[Hashable], multiNodeRPCClient] { + return newNodeWithState(t, chainID, nodeStateAlive) +} + +func newNodeWithState(t *testing.T, chainID types.ID, state nodeState) *mockNode[types.ID, types.Head[Hashable], multiNodeRPCClient] { + node := newMockNode[types.ID, types.Head[Hashable], multiNodeRPCClient](t) + node.On("ConfiguredChainID").Return(chainID).Once() + node.On("Start", mock.Anything).Return(nil).Once() + node.On("Close").Return(nil).Once() + node.On("State").Return(state).Maybe() + node.On("String").Return(fmt.Sprintf("healthy_node_%d", rand.Int())).Maybe() + return node +} +func TestMultiNode_Dial(t *testing.T) { + t.Parallel() + + newMockNode := newMockNode[types.ID, types.Head[Hashable], multiNodeRPCClient] + newMockSendOnlyNode := newMockSendOnlyNode[types.ID, multiNodeRPCClient] + + t.Run("Fails without nodes", func(t *testing.T) { + t.Parallel() + mn := newTestMultiNode(t, multiNodeOpts{ + selectionMode: NodeSelectionModeRoundRobin, + chainID: types.RandomID(), + }) + err := mn.Dial(tests.Context(t)) + assert.EqualError(t, err, fmt.Sprintf("no available nodes for chain %s", mn.chainID.String())) + }) + t.Run("Fails with wrong node's chainID", func(t *testing.T) { + t.Parallel() + node := newMockNode(t) + multiNodeChainID := types.NewIDFromInt(10) + nodeChainID := types.NewIDFromInt(11) + node.On("ConfiguredChainID").Return(nodeChainID).Twice() + const nodeName = "nodeName" + node.On("String").Return(nodeName).Once() + mn := newTestMultiNode(t, multiNodeOpts{ + selectionMode: NodeSelectionModeRoundRobin, + chainID: multiNodeChainID, + nodes: []Node[types.ID, types.Head[Hashable], multiNodeRPCClient]{node}, + }) + err := mn.Dial(tests.Context(t)) + assert.EqualError(t, err, fmt.Sprintf("node %s has configured chain ID %s which does not match multinode configured chain ID of %s", nodeName, nodeChainID, mn.chainID)) + }) + 
t.Run("Fails if node fails", func(t *testing.T) { + t.Parallel() + node := newMockNode(t) + chainID := types.RandomID() + node.On("ConfiguredChainID").Return(chainID).Once() + expectedError := errors.New("failed to start node") + node.On("Start", mock.Anything).Return(expectedError).Once() + mn := newTestMultiNode(t, multiNodeOpts{ + selectionMode: NodeSelectionModeRoundRobin, + chainID: chainID, + nodes: []Node[types.ID, types.Head[Hashable], multiNodeRPCClient]{node}, + }) + err := mn.Dial(tests.Context(t)) + assert.EqualError(t, err, expectedError.Error()) + }) + + t.Run("Closes started nodes on failure", func(t *testing.T) { + t.Parallel() + chainID := types.RandomID() + node1 := newHealthyNode(t, chainID) + node2 := newMockNode(t) + node2.On("ConfiguredChainID").Return(chainID).Once() + expectedError := errors.New("failed to start node") + node2.On("Start", mock.Anything).Return(expectedError).Once() + + mn := newTestMultiNode(t, multiNodeOpts{ + selectionMode: NodeSelectionModeRoundRobin, + chainID: chainID, + nodes: []Node[types.ID, types.Head[Hashable], multiNodeRPCClient]{node1, node2}, + }) + err := mn.Dial(tests.Context(t)) + assert.EqualError(t, err, expectedError.Error()) + }) + t.Run("Fails with wrong send only node's chainID", func(t *testing.T) { + t.Parallel() + multiNodeChainID := types.NewIDFromInt(10) + node := newHealthyNode(t, multiNodeChainID) + sendOnly := newMockSendOnlyNode(t) + sendOnlyChainID := types.NewIDFromInt(11) + sendOnly.On("ConfiguredChainID").Return(sendOnlyChainID).Twice() + const sendOnlyName = "sendOnlyNodeName" + sendOnly.On("String").Return(sendOnlyName).Once() + + mn := newTestMultiNode(t, multiNodeOpts{ + selectionMode: NodeSelectionModeRoundRobin, + chainID: multiNodeChainID, + nodes: []Node[types.ID, types.Head[Hashable], multiNodeRPCClient]{node}, + sendonlys: []SendOnlyNode[types.ID, multiNodeRPCClient]{sendOnly}, + }) + err := mn.Dial(tests.Context(t)) + assert.EqualError(t, err, fmt.Sprintf("sendonly node %s has configured chain ID %s which does not match multinode configured chain ID of %s", sendOnlyName, sendOnlyChainID, mn.chainID)) + }) + + newHealthySendOnly := func(t *testing.T, chainID types.ID) *mockSendOnlyNode[types.ID, multiNodeRPCClient] { + node := newMockSendOnlyNode(t) + node.On("ConfiguredChainID").Return(chainID).Once() + node.On("Start", mock.Anything).Return(nil).Once() + node.On("Close").Return(nil).Once() + return node + } + t.Run("Fails on send only node failure", func(t *testing.T) { + t.Parallel() + chainID := types.NewIDFromInt(10) + node := newHealthyNode(t, chainID) + sendOnly1 := newHealthySendOnly(t, chainID) + sendOnly2 := newMockSendOnlyNode(t) + sendOnly2.On("ConfiguredChainID").Return(chainID).Once() + expectedError := errors.New("failed to start send only node") + sendOnly2.On("Start", mock.Anything).Return(expectedError).Once() + + mn := newTestMultiNode(t, multiNodeOpts{ + selectionMode: NodeSelectionModeRoundRobin, + chainID: chainID, + nodes: []Node[types.ID, types.Head[Hashable], multiNodeRPCClient]{node}, + sendonlys: []SendOnlyNode[types.ID, multiNodeRPCClient]{sendOnly1, sendOnly2}, + }) + err := mn.Dial(tests.Context(t)) + assert.EqualError(t, err, expectedError.Error()) + }) + t.Run("Starts successfully with healthy nodes", func(t *testing.T) { + t.Parallel() + chainID := types.NewIDFromInt(10) + node := newHealthyNode(t, chainID) + mn := newTestMultiNode(t, multiNodeOpts{ + selectionMode: NodeSelectionModeRoundRobin, + chainID: chainID, + nodes: []Node[types.ID, types.Head[Hashable], 
multiNodeRPCClient]{node}, + sendonlys: []SendOnlyNode[types.ID, multiNodeRPCClient]{newHealthySendOnly(t, chainID)}, + }) + defer func() { assert.NoError(t, mn.Close()) }() + err := mn.Dial(tests.Context(t)) + require.NoError(t, err) + selectedNode, err := mn.selectNode() + require.NoError(t, err) + assert.Equal(t, node, selectedNode) + }) +} + +func TestMultiNode_Report(t *testing.T) { + t.Parallel() + t.Run("Dial starts periodical reporting", func(t *testing.T) { + t.Parallel() + chainID := types.RandomID() + node1 := newHealthyNode(t, chainID) + node2 := newNodeWithState(t, chainID, nodeStateOutOfSync) + lggr, observedLogs := logger.TestLoggerObserved(t, zap.WarnLevel) + mn := newTestMultiNode(t, multiNodeOpts{ + selectionMode: NodeSelectionModeRoundRobin, + chainID: chainID, + nodes: []Node[types.ID, types.Head[Hashable], multiNodeRPCClient]{node1, node2}, + logger: lggr, + }) + mn.reportInterval = tests.TestInterval + defer func() { assert.NoError(t, mn.Close()) }() + err := mn.Dial(tests.Context(t)) + require.NoError(t, err) + tests.AssertLogCountEventually(t, observedLogs, "At least one primary node is dead: 1/2 nodes are alive", 2) + }) + t.Run("Report critical error on all node failure", func(t *testing.T) { + t.Parallel() + chainID := types.RandomID() + node := newNodeWithState(t, chainID, nodeStateOutOfSync) + lggr, observedLogs := logger.TestLoggerObserved(t, zap.WarnLevel) + mn := newTestMultiNode(t, multiNodeOpts{ + selectionMode: NodeSelectionModeRoundRobin, + chainID: chainID, + nodes: []Node[types.ID, types.Head[Hashable], multiNodeRPCClient]{node}, + logger: lggr, + }) + mn.reportInterval = tests.TestInterval + defer func() { assert.NoError(t, mn.Close()) }() + err := mn.Dial(tests.Context(t)) + require.NoError(t, err) + tests.AssertLogCountEventually(t, observedLogs, "no primary nodes available: 0/1 nodes are alive", 2) + err = mn.Healthy() + require.Error(t, err) + assert.Contains(t, err.Error(), "no primary nodes available: 0/1 nodes are alive") + }) +} + +func TestMultiNode_CheckLease(t *testing.T) { + t.Parallel() + t.Run("Round robin disables lease check", func(t *testing.T) { + t.Parallel() + chainID := types.RandomID() + node := newHealthyNode(t, chainID) + lggr, observedLogs := logger.TestLoggerObserved(t, zap.InfoLevel) + mn := newTestMultiNode(t, multiNodeOpts{ + selectionMode: NodeSelectionModeRoundRobin, + chainID: chainID, + logger: lggr, + nodes: []Node[types.ID, types.Head[Hashable], multiNodeRPCClient]{node}, + }) + defer func() { assert.NoError(t, mn.Close()) }() + err := mn.Dial(tests.Context(t)) + require.NoError(t, err) + tests.RequireLogMessage(t, observedLogs, "Best node switching is disabled") + }) + t.Run("Misconfigured lease check period won't start", func(t *testing.T) { + t.Parallel() + chainID := types.RandomID() + node := newHealthyNode(t, chainID) + lggr, observedLogs := logger.TestLoggerObserved(t, zap.InfoLevel) + mn := newTestMultiNode(t, multiNodeOpts{ + selectionMode: NodeSelectionModeHighestHead, + chainID: chainID, + logger: lggr, + nodes: []Node[types.ID, types.Head[Hashable], multiNodeRPCClient]{node}, + leaseDuration: 0, + }) + defer func() { assert.NoError(t, mn.Close()) }() + err := mn.Dial(tests.Context(t)) + require.NoError(t, err) + tests.RequireLogMessage(t, observedLogs, "Best node switching is disabled") + }) + t.Run("Lease check updates active node", func(t *testing.T) { + t.Parallel() + chainID := types.RandomID() + node := newHealthyNode(t, chainID) + node.On("SubscribersCount").Return(int32(2)) + 
node.On("UnsubscribeAllExceptAliveLoop") + bestNode := newHealthyNode(t, chainID) + nodeSelector := newMockNodeSelector[types.ID, types.Head[Hashable], multiNodeRPCClient](t) + nodeSelector.On("Select").Return(bestNode) + lggr, observedLogs := logger.TestLoggerObserved(t, zap.InfoLevel) + mn := newTestMultiNode(t, multiNodeOpts{ + selectionMode: NodeSelectionModeHighestHead, + chainID: chainID, + logger: lggr, + nodes: []Node[types.ID, types.Head[Hashable], multiNodeRPCClient]{node, bestNode}, + leaseDuration: tests.TestInterval, + }) + defer func() { assert.NoError(t, mn.Close()) }() + mn.nodeSelector = nodeSelector + err := mn.Dial(tests.Context(t)) + require.NoError(t, err) + tests.AssertLogEventually(t, observedLogs, fmt.Sprintf("Switching to best node from %q to %q", node.String(), bestNode.String())) + tests.AssertEventually(t, func() bool { + mn.activeMu.RLock() + active := mn.activeNode + mn.activeMu.RUnlock() + return bestNode == active + }) + }) + t.Run("NodeStates returns proper states", func(t *testing.T) { + t.Parallel() + chainID := types.NewIDFromInt(10) + nodes := map[string]nodeState{ + "node_1": nodeStateAlive, + "node_2": nodeStateUnreachable, + "node_3": nodeStateDialed, + } + + opts := multiNodeOpts{ + selectionMode: NodeSelectionModeRoundRobin, + chainID: chainID, + } + + expectedResult := map[string]string{} + for name, state := range nodes { + node := newMockNode[types.ID, types.Head[Hashable], multiNodeRPCClient](t) + node.On("Name").Return(name).Once() + node.On("State").Return(state).Once() + opts.nodes = append(opts.nodes, node) + + sendOnly := newMockSendOnlyNode[types.ID, multiNodeRPCClient](t) + sendOnlyName := "send_only_" + name + sendOnly.On("Name").Return(sendOnlyName).Once() + sendOnly.On("State").Return(state).Once() + opts.sendonlys = append(opts.sendonlys, sendOnly) + + expectedResult[name] = state.String() + expectedResult[sendOnlyName] = state.String() + } + + mn := newTestMultiNode(t, opts) + states := mn.NodeStates() + assert.Equal(t, expectedResult, states) + }) +} + +func TestMultiNode_selectNode(t *testing.T) { + t.Parallel() + t.Run("Returns same node, if it's still healthy", func(t *testing.T) { + t.Parallel() + chainID := types.RandomID() + node1 := newMockNode[types.ID, types.Head[Hashable], multiNodeRPCClient](t) + node1.On("State").Return(nodeStateAlive).Once() + node1.On("String").Return("node1").Maybe() + node2 := newMockNode[types.ID, types.Head[Hashable], multiNodeRPCClient](t) + node2.On("String").Return("node2").Maybe() + mn := newTestMultiNode(t, multiNodeOpts{ + selectionMode: NodeSelectionModeRoundRobin, + chainID: chainID, + nodes: []Node[types.ID, types.Head[Hashable], multiNodeRPCClient]{node1, node2}, + }) + nodeSelector := newMockNodeSelector[types.ID, types.Head[Hashable], multiNodeRPCClient](t) + nodeSelector.On("Select").Return(node1).Once() + mn.nodeSelector = nodeSelector + prevActiveNode, err := mn.selectNode() + require.NoError(t, err) + require.Equal(t, node1.String(), prevActiveNode.String()) + newActiveNode, err := mn.selectNode() + require.NoError(t, err) + require.Equal(t, prevActiveNode.String(), newActiveNode.String()) + + }) + t.Run("Updates node if active is not healthy", func(t *testing.T) { + t.Parallel() + chainID := types.RandomID() + oldBest := newMockNode[types.ID, types.Head[Hashable], multiNodeRPCClient](t) + oldBest.On("String").Return("oldBest").Maybe() + newBest := newMockNode[types.ID, types.Head[Hashable], multiNodeRPCClient](t) + newBest.On("String").Return("newBest").Maybe() + mn := 
newTestMultiNode(t, multiNodeOpts{ + selectionMode: NodeSelectionModeRoundRobin, + chainID: chainID, + nodes: []Node[types.ID, types.Head[Hashable], multiNodeRPCClient]{oldBest, newBest}, + }) + nodeSelector := newMockNodeSelector[types.ID, types.Head[Hashable], multiNodeRPCClient](t) + nodeSelector.On("Select").Return(oldBest).Once() + mn.nodeSelector = nodeSelector + activeNode, err := mn.selectNode() + require.NoError(t, err) + require.Equal(t, oldBest.String(), activeNode.String()) + // old best died, so we should replace it + oldBest.On("State").Return(nodeStateOutOfSync).Twice() + nodeSelector.On("Select").Return(newBest).Once() + newActiveNode, err := mn.selectNode() + require.NoError(t, err) + require.Equal(t, newBest.String(), newActiveNode.String()) + + }) + t.Run("No active nodes - reports critical error", func(t *testing.T) { + t.Parallel() + chainID := types.RandomID() + lggr, observedLogs := logger.TestLoggerObserved(t, zap.InfoLevel) + mn := newTestMultiNode(t, multiNodeOpts{ + selectionMode: NodeSelectionModeRoundRobin, + chainID: chainID, + logger: lggr, + }) + nodeSelector := newMockNodeSelector[types.ID, types.Head[Hashable], multiNodeRPCClient](t) + nodeSelector.On("Select").Return(nil).Once() + nodeSelector.On("Name").Return("MockedNodeSelector").Once() + mn.nodeSelector = nodeSelector + node, err := mn.selectNode() + require.EqualError(t, err, ErroringNodeError.Error()) + require.Nil(t, node) + tests.RequireLogMessage(t, observedLogs, "No live RPC nodes available") + + }) +} + +func TestMultiNode_nLiveNodes(t *testing.T) { + t.Parallel() + type nodeParams struct { + BlockNumber int64 + TotalDifficulty *utils.Big + State nodeState + } + testCases := []struct { + Name string + ExpectedNLiveNodes int + ExpectedBlockNumber int64 + ExpectedTotalDifficulty *utils.Big + NodeParams []nodeParams + }{ + { + Name: "no nodes", + ExpectedTotalDifficulty: utils.NewBigI(0), + }, + { + Name: "Best node is not healthy", + ExpectedTotalDifficulty: utils.NewBigI(10), + ExpectedBlockNumber: 20, + ExpectedNLiveNodes: 3, + NodeParams: []nodeParams{ + { + State: nodeStateOutOfSync, + BlockNumber: 1000, + TotalDifficulty: utils.NewBigI(2000), + }, + { + State: nodeStateAlive, + BlockNumber: 20, + TotalDifficulty: utils.NewBigI(9), + }, + { + State: nodeStateAlive, + BlockNumber: 19, + TotalDifficulty: utils.NewBigI(10), + }, + { + State: nodeStateAlive, + BlockNumber: 11, + TotalDifficulty: nil, + }, + }, + }, + } + + chainID := types.RandomID() + mn := newTestMultiNode(t, multiNodeOpts{ + selectionMode: NodeSelectionModeRoundRobin, + chainID: chainID, + }) + for i := range testCases { + tc := testCases[i] + t.Run(tc.Name, func(t *testing.T) { + for _, params := range tc.NodeParams { + node := newMockNode[types.ID, types.Head[Hashable], multiNodeRPCClient](t) + node.On("StateAndLatest").Return(params.State, params.BlockNumber, params.TotalDifficulty) + mn.nodes = append(mn.nodes, node) + } + + nNodes, blockNum, td := mn.nLiveNodes() + assert.Equal(t, tc.ExpectedNLiveNodes, nNodes) + assert.Equal(t, tc.ExpectedTotalDifficulty, td) + assert.Equal(t, tc.ExpectedBlockNumber, blockNum) + }) + } +} + +func TestMultiNode_BatchCallContextAll(t *testing.T) { + t.Parallel() + t.Run("Fails if failed to select active node", func(t *testing.T) { + chainID := types.RandomID() + mn := newTestMultiNode(t, multiNodeOpts{ + selectionMode: NodeSelectionModeRoundRobin, + chainID: chainID, + }) + nodeSelector := newMockNodeSelector[types.ID, types.Head[Hashable], multiNodeRPCClient](t) + 
nodeSelector.On("Select").Return(nil).Once() + nodeSelector.On("Name").Return("MockedNodeSelector").Once() + mn.nodeSelector = nodeSelector + err := mn.BatchCallContextAll(tests.Context(t), nil) + require.EqualError(t, err, ErroringNodeError.Error()) + }) + t.Run("Returns error if RPC call fails for active node", func(t *testing.T) { + chainID := types.RandomID() + rpc := newMultiNodeRPCClient(t) + expectedError := errors.New("rpc failed to do the batch call") + rpc.On("BatchCallContext", mock.Anything, mock.Anything).Return(expectedError).Once() + node := newMockNode[types.ID, types.Head[Hashable], multiNodeRPCClient](t) + node.On("RPC").Return(rpc) + nodeSelector := newMockNodeSelector[types.ID, types.Head[Hashable], multiNodeRPCClient](t) + nodeSelector.On("Select").Return(node).Once() + mn := newTestMultiNode(t, multiNodeOpts{ + selectionMode: NodeSelectionModeRoundRobin, + chainID: chainID, + }) + mn.nodeSelector = nodeSelector + err := mn.BatchCallContextAll(tests.Context(t), nil) + require.EqualError(t, err, expectedError.Error()) + }) + t.Run("Waits for all nodes to complete the call and logs results", func(t *testing.T) { + // setup RPCs + failedRPC := newMultiNodeRPCClient(t) + failedRPC.On("BatchCallContext", mock.Anything, mock.Anything). + Return(errors.New("rpc failed to do the batch call")).Once() + okRPC := newMultiNodeRPCClient(t) + okRPC.On("BatchCallContext", mock.Anything, mock.Anything).Return(nil).Twice() + + // setup ok and failed auxiliary nodes + okNode := newMockSendOnlyNode[types.ID, multiNodeRPCClient](t) + okNode.On("RPC").Return(okRPC).Once() + failedNode := newMockNode[types.ID, types.Head[Hashable], multiNodeRPCClient](t) + failedNode.On("RPC").Return(failedRPC).Once() + + // setup main node + mainNode := newMockNode[types.ID, types.Head[Hashable], multiNodeRPCClient](t) + mainNode.On("RPC").Return(okRPC) + nodeSelector := newMockNodeSelector[types.ID, types.Head[Hashable], multiNodeRPCClient](t) + nodeSelector.On("Select").Return(mainNode).Once() + lggr, observedLogs := logger.TestLoggerObserved(t, zap.DebugLevel) + mn := newTestMultiNode(t, multiNodeOpts{ + selectionMode: NodeSelectionModeRoundRobin, + chainID: types.RandomID(), + nodes: []Node[types.ID, types.Head[Hashable], multiNodeRPCClient]{failedNode, mainNode}, + sendonlys: []SendOnlyNode[types.ID, multiNodeRPCClient]{okNode}, + logger: lggr, + }) + mn.nodeSelector = nodeSelector + + err := mn.BatchCallContextAll(tests.Context(t), nil) + require.NoError(t, err) + tests.RequireLogMessage(t, observedLogs, "Secondary node BatchCallContext failed") + }) +} + +func TestMultiNode_SendTransaction(t *testing.T) { + t.Parallel() + t.Run("Fails if failed to select active node", func(t *testing.T) { + chainID := types.RandomID() + mn := newTestMultiNode(t, multiNodeOpts{ + selectionMode: NodeSelectionModeRoundRobin, + chainID: chainID, + }) + nodeSelector := newMockNodeSelector[types.ID, types.Head[Hashable], multiNodeRPCClient](t) + nodeSelector.On("Select").Return(nil).Once() + nodeSelector.On("Name").Return("MockedNodeSelector").Once() + mn.nodeSelector = nodeSelector + err := mn.SendTransaction(tests.Context(t), nil) + require.EqualError(t, err, ErroringNodeError.Error()) + }) + t.Run("Returns error if RPC call fails for active node", func(t *testing.T) { + chainID := types.RandomID() + rpc := newMultiNodeRPCClient(t) + expectedError := errors.New("rpc failed to do the batch call") + rpc.On("SendTransaction", mock.Anything, mock.Anything).Return(expectedError).Once() + node := newMockNode[types.ID, 
types.Head[Hashable], multiNodeRPCClient](t) + node.On("RPC").Return(rpc) + nodeSelector := newMockNodeSelector[types.ID, types.Head[Hashable], multiNodeRPCClient](t) + nodeSelector.On("Select").Return(node).Once() + mn := newTestMultiNode(t, multiNodeOpts{ + selectionMode: NodeSelectionModeRoundRobin, + chainID: chainID, + }) + mn.nodeSelector = nodeSelector + err := mn.SendTransaction(tests.Context(t), nil) + require.EqualError(t, err, expectedError.Error()) + }) + t.Run("Returns result of main node and logs secondary nodes results", func(t *testing.T) { + // setup RPCs + failedRPC := newMultiNodeRPCClient(t) + failedRPC.On("SendTransaction", mock.Anything, mock.Anything). + Return(errors.New("rpc failed to do the batch call")).Once() + okRPC := newMultiNodeRPCClient(t) + okRPC.On("SendTransaction", mock.Anything, mock.Anything).Return(nil).Twice() + + // setup ok and failed auxiliary nodes + okNode := newMockSendOnlyNode[types.ID, multiNodeRPCClient](t) + okNode.On("RPC").Return(okRPC).Once() + okNode.On("String").Return("okNode") + failedNode := newMockNode[types.ID, types.Head[Hashable], multiNodeRPCClient](t) + failedNode.On("RPC").Return(failedRPC).Once() + failedNode.On("String").Return("failedNode") + + // setup main node + mainNode := newMockNode[types.ID, types.Head[Hashable], multiNodeRPCClient](t) + mainNode.On("RPC").Return(okRPC) + nodeSelector := newMockNodeSelector[types.ID, types.Head[Hashable], multiNodeRPCClient](t) + nodeSelector.On("Select").Return(mainNode).Once() + lggr, observedLogs := logger.TestLoggerObserved(t, zap.DebugLevel) + mn := newTestMultiNode(t, multiNodeOpts{ + selectionMode: NodeSelectionModeRoundRobin, + chainID: types.RandomID(), + nodes: []Node[types.ID, types.Head[Hashable], multiNodeRPCClient]{failedNode, mainNode}, + sendonlys: []SendOnlyNode[types.ID, multiNodeRPCClient]{okNode}, + logger: lggr, + sendOnlyErrorParser: func(err error) SendTxReturnCode { + if err != nil { + return Fatal + } + + return Successful + }, + }) + mn.nodeSelector = nodeSelector + + err := mn.SendTransaction(tests.Context(t), nil) + require.NoError(t, err) + tests.AssertLogEventually(t, observedLogs, "Sendonly node sent transaction") + tests.AssertLogEventually(t, observedLogs, "RPC returned error") + }) +} diff --git a/common/client/node.go b/common/client/node.go index 20d098e03f7..f28a171a558 100644 --- a/common/client/node.go +++ b/common/client/node.go @@ -44,6 +44,7 @@ type NodeConfig interface { SyncThreshold() uint32 } +//go:generate mockery --quiet --name Node --structname mockNode --filename "mock_node_test.go" --inpackage --case=underscore type Node[ CHAIN_ID types.ID, HEAD Head, @@ -177,19 +178,21 @@ func (n *node[CHAIN_ID, HEAD, RPC]) UnsubscribeAllExceptAliveLoop() { } func (n *node[CHAIN_ID, HEAD, RPC]) Close() error { - return n.StopOnce(n.name, func() error { - defer func() { - n.wg.Wait() - n.rpc.Close() - }() + return n.StopOnce(n.name, n.close) +} - n.stateMu.Lock() - defer n.stateMu.Unlock() +func (n *node[CHAIN_ID, HEAD, RPC]) close() error { + defer func() { + n.wg.Wait() + n.rpc.Close() + }() - n.cancelNodeCtx() - n.state = nodeStateClosed - return nil - }) + n.stateMu.Lock() + defer n.stateMu.Unlock() + + n.cancelNodeCtx() + n.state = nodeStateClosed + return nil } // Start dials and verifies the node diff --git a/common/client/node_fsm_test.go b/common/client/node_fsm_test.go new file mode 100644 index 00000000000..87e90846699 --- /dev/null +++ b/common/client/node_fsm_test.go @@ -0,0 +1,108 @@ +package client + +import ( + "slices" + "testing" 
+ + "github.com/stretchr/testify/assert" + + "github.com/smartcontractkit/chainlink/v2/common/types" +) + +type fnMock struct{ calls int } + +func (fm *fnMock) Fn() { + fm.calls++ +} + +func (fm *fnMock) AssertNotCalled(t *testing.T) { + assert.Equal(t, 0, fm.calls) +} + +func (fm *fnMock) AssertCalled(t *testing.T) { + assert.Greater(t, fm.calls, 0) +} + +func newTestTransitionNode(t *testing.T, rpc *mockNodeClient[types.ID, Head]) testNode { + return newTestNode(t, testNodeOpts{rpc: rpc}) +} + +func TestUnit_Node_StateTransitions(t *testing.T) { + t.Parallel() + + t.Run("setState", func(t *testing.T) { + n := newTestTransitionNode(t, nil) + assert.Equal(t, nodeStateUndialed, n.State()) + n.setState(nodeStateAlive) + assert.Equal(t, nodeStateAlive, n.State()) + n.setState(nodeStateUndialed) + assert.Equal(t, nodeStateUndialed, n.State()) + }) + + t.Run("transitionToAlive", func(t *testing.T) { + const destinationState = nodeStateAlive + allowedStates := []nodeState{nodeStateDialed, nodeStateInvalidChainID} + rpc := newMockNodeClient[types.ID, Head](t) + testTransition(t, rpc, testNode.transitionToAlive, destinationState, allowedStates...) + }) + + t.Run("transitionToInSync", func(t *testing.T) { + const destinationState = nodeStateAlive + allowedStates := []nodeState{nodeStateOutOfSync} + rpc := newMockNodeClient[types.ID, Head](t) + testTransition(t, rpc, testNode.transitionToInSync, destinationState, allowedStates...) + }) + t.Run("transitionToOutOfSync", func(t *testing.T) { + const destinationState = nodeStateOutOfSync + allowedStates := []nodeState{nodeStateAlive} + rpc := newMockNodeClient[types.ID, Head](t) + rpc.On("DisconnectAll").Once() + testTransition(t, rpc, testNode.transitionToOutOfSync, destinationState, allowedStates...) + }) + t.Run("transitionToUnreachable", func(t *testing.T) { + const destinationState = nodeStateUnreachable + allowedStates := []nodeState{nodeStateUndialed, nodeStateDialed, nodeStateAlive, nodeStateOutOfSync, nodeStateInvalidChainID} + rpc := newMockNodeClient[types.ID, Head](t) + rpc.On("DisconnectAll").Times(len(allowedStates)) + testTransition(t, rpc, testNode.transitionToUnreachable, destinationState, allowedStates...) + }) + t.Run("transitionToInvalidChain", func(t *testing.T) { + const destinationState = nodeStateInvalidChainID + allowedStates := []nodeState{nodeStateDialed, nodeStateOutOfSync} + rpc := newMockNodeClient[types.ID, Head](t) + rpc.On("DisconnectAll").Times(len(allowedStates)) + testTransition(t, rpc, testNode.transitionToInvalidChainID, destinationState, allowedStates...) 
+ }) +} + +func testTransition(t *testing.T, rpc *mockNodeClient[types.ID, Head], transition func(node testNode, fn func()), destinationState nodeState, allowedStates ...nodeState) { + node := newTestTransitionNode(t, rpc) + for _, allowedState := range allowedStates { + m := new(fnMock) + node.setState(allowedState) + transition(node, m.Fn) + assert.Equal(t, destinationState, node.State(), "Expected node to successfully transition from %s to %s state", allowedState, destinationState) + m.AssertCalled(t) + } + // noop on attempt to transition from Closed state + m := new(fnMock) + node.setState(nodeStateClosed) + transition(node, m.Fn) + m.AssertNotCalled(t) + assert.Equal(t, nodeStateClosed, node.State(), "Expected node to remain in closed state on transition attempt") + + for _, nodeState := range allNodeStates { + if slices.Contains(allowedStates, nodeState) || nodeState == nodeStateClosed { + continue + } + + m := new(fnMock) + node.setState(nodeState) + assert.Panics(t, func() { + transition(node, m.Fn) + }, "Expected transition from `%s` to `%s` to panic", nodeState, destinationState) + m.AssertNotCalled(t) + assert.Equal(t, nodeState, node.State(), "Expected node to remain in initial state on invalid transition") + + } +} diff --git a/common/client/node_lifecycle.go b/common/client/node_lifecycle.go index 149c5f01a6d..4193560e296 100644 --- a/common/client/node_lifecycle.go +++ b/common/client/node_lifecycle.go @@ -60,6 +60,8 @@ const ( msgDegradedState = "Chainlink is now operating in a degraded state and urgent action is required to resolve the issue" ) +const rpcSubscriptionMethodNewHeads = "newHeads" + // Node is a FSM // Each state has a loop that goes with it, which monitors the node and moves it into another state as necessary. // Only one loop must run at a time. @@ -90,12 +92,14 @@ func (n *node[CHAIN_ID, HEAD, RPC]) aliveLoop() { lggr.Tracew("Alive loop starting", "nodeState", n.State()) headsC := make(chan HEAD) - sub, err := n.rpc.Subscribe(n.nodeCtx, headsC, "newHeads") + sub, err := n.rpc.Subscribe(n.nodeCtx, headsC, rpcSubscriptionMethodNewHeads) if err != nil { lggr.Errorw("Initial subscribe for heads failed", "nodeState", n.State()) n.declareUnreachable() return } + // TODO: nit fix. If multinode switches primary node before we set sub as AliveSub, sub will be closed and we'll + // falsely transition this node to unreachable state n.rpc.SetAliveLoopSub(sub) defer sub.Unsubscribe() @@ -289,7 +293,7 @@ func (n *node[CHAIN_ID, HEAD, RPC]) outOfSyncLoop(isOutOfSync func(num int64, td lggr.Tracew("Successfully subscribed to heads feed on out-of-sync RPC node", "nodeState", n.State()) ch := make(chan HEAD) - sub, err := n.rpc.Subscribe(n.nodeCtx, ch, "newHeads") + sub, err := n.rpc.Subscribe(n.nodeCtx, ch, rpcSubscriptionMethodNewHeads) if err != nil { lggr.Errorw("Failed to subscribe heads on out-of-sync RPC node", "nodeState", n.State(), "err", err) n.declareUnreachable() @@ -310,11 +314,11 @@ func (n *node[CHAIN_ID, HEAD, RPC]) outOfSyncLoop(isOutOfSync func(num int64, td n.setLatestReceived(head.BlockNumber(), head.BlockDifficulty()) if !isOutOfSync(head.BlockNumber(), head.BlockDifficulty()) { // back in-sync! flip back into alive loop - lggr.Infow(fmt.Sprintf("%s: %s. Node was out-of-sync for %s", msgInSync, n.String(), time.Since(outOfSyncAt)), "blockNumber", head.BlockNumber(), "totalDifficulty", "nodeState", n.State()) + lggr.Infow(fmt.Sprintf("%s: %s. 
Node was out-of-sync for %s", msgInSync, n.String(), time.Since(outOfSyncAt)), "blockNumber", head.BlockNumber(), "blockDifficulty", head.BlockDifficulty(), "nodeState", n.State()) n.declareInSync() return } - lggr.Debugw(msgReceivedBlock, "blockNumber", head.BlockNumber(), "totalDifficulty", "nodeState", n.State()) + lggr.Debugw(msgReceivedBlock, "blockNumber", head.BlockNumber(), "blockDifficulty", head.BlockDifficulty(), "nodeState", n.State()) case <-time.After(zombieNodeCheckInterval(n.noNewHeadsThreshold)): if n.nLiveNodes != nil { if l, _, _ := n.nLiveNodes(); l < 1 { diff --git a/common/client/node_lifecycle_test.go b/common/client/node_lifecycle_test.go new file mode 100644 index 00000000000..564c08bbdcc --- /dev/null +++ b/common/client/node_lifecycle_test.go @@ -0,0 +1,1070 @@ +package client + +import ( + "fmt" + "sync/atomic" + "testing" + + "github.com/cometbft/cometbft/libs/rand" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "go.uber.org/zap" + + "github.com/smartcontractkit/chainlink-relay/pkg/utils/tests" + + "github.com/smartcontractkit/chainlink/v2/common/types" + "github.com/smartcontractkit/chainlink/v2/common/types/mocks" + "github.com/smartcontractkit/chainlink/v2/core/logger" + "github.com/smartcontractkit/chainlink/v2/core/utils" +) + +func TestUnit_NodeLifecycle_aliveLoop(t *testing.T) { + t.Parallel() + + newDialedNode := func(t *testing.T, opts testNodeOpts) testNode { + node := newTestNode(t, opts) + opts.rpc.On("Close").Return(nil).Once() + + node.setState(nodeStateDialed) + return node + } + + t.Run("returns on closed", func(t *testing.T) { + node := newTestNode(t, testNodeOpts{}) + node.setState(nodeStateClosed) + node.wg.Add(1) + node.aliveLoop() + + }) + t.Run("if initial subscribe fails, transitions to unreachable", func(t *testing.T) { + t.Parallel() + rpc := newMockNodeClient[types.ID, Head](t) + node := newDialedNode(t, testNodeOpts{ + rpc: rpc, + }) + defer func() { assert.NoError(t, node.close()) }() + + expectedError := errors.New("failed to subscribe to rpc") + rpc.On("Subscribe", mock.Anything, mock.Anything, rpcSubscriptionMethodNewHeads).Return(nil, expectedError).Once() + rpc.On("DisconnectAll").Once() + // might be called in unreachable loop + rpc.On("Dial", mock.Anything).Return(errors.New("failed to dial")).Maybe() + node.declareAlive() + tests.AssertEventually(t, func() bool { + return node.State() == nodeStateUnreachable + }) + + }) + t.Run("if remote RPC connection is closed transitions to unreachable", func(t *testing.T) { + t.Parallel() + rpc := newMockNodeClient[types.ID, Head](t) + + lggr, observedLogs := logger.TestLoggerObserved(t, zap.WarnLevel) + node := newDialedNode(t, testNodeOpts{ + rpc: rpc, + lggr: lggr, + }) + defer func() { assert.NoError(t, node.close()) }() + + sub := mocks.NewSubscription(t) + errChan := make(chan error) + close(errChan) + sub.On("Err").Return((<-chan error)(errChan)).Once() + sub.On("Unsubscribe").Once() + rpc.On("Subscribe", mock.Anything, mock.Anything, rpcSubscriptionMethodNewHeads).Return(sub, nil).Once() + rpc.On("SetAliveLoopSub", sub).Once() + // disconnects all on transfer to unreachable + rpc.On("DisconnectAll").Once() + // might be called in unreachable loop + rpc.On("Dial", mock.Anything).Return(errors.New("failed to dial")).Maybe() + node.declareAlive() + tests.AssertLogEventually(t, observedLogs, "Subscription was terminated") + assert.Equal(t, nodeStateUnreachable, node.State()) + }) + + newSubscribedNode := func(t *testing.T, opts 
testNodeOpts) testNode { + sub := mocks.NewSubscription(t) + sub.On("Err").Return((<-chan error)(nil)) + sub.On("Unsubscribe").Once() + opts.rpc.On("Subscribe", mock.Anything, mock.Anything, rpcSubscriptionMethodNewHeads).Return(sub, nil).Once() + opts.rpc.On("SetAliveLoopSub", sub).Once() + return newDialedNode(t, opts) + } + t.Run("Stays alive and waits for signal", func(t *testing.T) { + t.Parallel() + rpc := newMockNodeClient[types.ID, Head](t) + lggr, observedLogs := logger.TestLoggerObserved(t, zap.DebugLevel) + node := newSubscribedNode(t, testNodeOpts{ + config: testNodeConfig{}, + rpc: rpc, + lggr: lggr, + }) + defer func() { assert.NoError(t, node.close()) }() + node.declareAlive() + tests.AssertLogEventually(t, observedLogs, "Head liveness checking disabled") + tests.AssertLogEventually(t, observedLogs, "Polling disabled") + assert.Equal(t, nodeStateAlive, node.State()) + }) + t.Run("stays alive while below pollFailureThreshold and resets counter on success", func(t *testing.T) { + t.Parallel() + rpc := newMockNodeClient[types.ID, Head](t) + lggr, observedLogs := logger.TestLoggerObserved(t, zap.DebugLevel) + const pollFailureThreshold = 3 + node := newSubscribedNode(t, testNodeOpts{ + config: testNodeConfig{ + pollFailureThreshold: pollFailureThreshold, + pollInterval: tests.TestInterval, + }, + rpc: rpc, + lggr: lggr, + }) + defer func() { assert.NoError(t, node.close()) }() + + pollError := errors.New("failed to get ClientVersion") + // 1. Return error several times, but below threshold + rpc.On("ClientVersion", mock.Anything).Return("", pollError).Run(func(_ mock.Arguments) { + // stays healthy while below threshold + assert.Equal(t, nodeStateAlive, node.State()) + }).Times(pollFailureThreshold - 1) + // 2. Successful call that is expected to reset counter + rpc.On("ClientVersion", mock.Anything).Return("client_version", nil).Once() + // 3. Return error. If the counter was not reset, this failure would push us over the threshold and out of the alive state + rpc.On("ClientVersion", mock.Anything).Return("", pollError).Once() + // 4. 
Once during the call, check if node is alive + var ensuredAlive atomic.Bool + rpc.On("ClientVersion", mock.Anything).Return("client_version", nil).Run(func(_ mock.Arguments) { + if ensuredAlive.Load() { + return + } + ensuredAlive.Store(true) + assert.Equal(t, nodeStateAlive, node.State()) + }).Once() + // redundant call to stay in alive state + rpc.On("ClientVersion", mock.Anything).Return("client_version", nil) + node.declareAlive() + tests.AssertLogCountEventually(t, observedLogs, fmt.Sprintf("Poll failure, RPC endpoint %s failed to respond properly", node.String()), pollFailureThreshold) + tests.AssertLogCountEventually(t, observedLogs, "Version poll successful", 2) + assert.True(t, ensuredAlive.Load(), "expected to ensure that node was alive") + + }) + t.Run("with threshold poll failures, transitions to unreachable", func(t *testing.T) { + t.Parallel() + rpc := newMockNodeClient[types.ID, Head](t) + lggr, observedLogs := logger.TestLoggerObserved(t, zap.DebugLevel) + const pollFailureThreshold = 3 + node := newSubscribedNode(t, testNodeOpts{ + config: testNodeConfig{ + pollFailureThreshold: pollFailureThreshold, + pollInterval: tests.TestInterval, + }, + rpc: rpc, + lggr: lggr, + }) + defer func() { assert.NoError(t, node.close()) }() + pollError := errors.New("failed to get ClientVersion") + rpc.On("ClientVersion", mock.Anything).Return("", pollError) + // disconnects all on transfer to unreachable + rpc.On("DisconnectAll").Once() + // might be called in unreachable loop + rpc.On("Dial", mock.Anything).Return(errors.New("failed to dial")).Maybe() + node.declareAlive() + tests.AssertLogCountEventually(t, observedLogs, fmt.Sprintf("Poll failure, RPC endpoint %s failed to respond properly", node.String()), pollFailureThreshold) + tests.AssertEventually(t, func() bool { + return nodeStateUnreachable == node.State() + }) + }) + t.Run("with threshold poll failures, but we are the last node alive, forcibly keeps it alive", func(t *testing.T) { + t.Parallel() + rpc := newMockNodeClient[types.ID, Head](t) + lggr, observedLogs := logger.TestLoggerObserved(t, zap.DebugLevel) + const pollFailureThreshold = 3 + node := newSubscribedNode(t, testNodeOpts{ + config: testNodeConfig{ + pollFailureThreshold: pollFailureThreshold, + pollInterval: tests.TestInterval, + }, + rpc: rpc, + lggr: lggr, + }) + defer func() { assert.NoError(t, node.close()) }() + node.nLiveNodes = func() (count int, blockNumber int64, totalDifficulty *utils.Big) { + return 1, 20, utils.NewBigI(10) + } + pollError := errors.New("failed to get ClientVersion") + rpc.On("ClientVersion", mock.Anything).Return("", pollError) + node.declareAlive() + tests.AssertLogEventually(t, observedLogs, fmt.Sprintf("RPC endpoint failed to respond to %d consecutive polls", pollFailureThreshold)) + assert.Equal(t, nodeStateAlive, node.State()) + }) + t.Run("when behind more than SyncThreshold, transitions to out of sync", func(t *testing.T) { + t.Parallel() + rpc := newMockNodeClient[types.ID, Head](t) + lggr, observedLogs := logger.TestLoggerObserved(t, zap.DebugLevel) + const syncThreshold = 10 + node := newSubscribedNode(t, testNodeOpts{ + config: testNodeConfig{ + pollInterval: tests.TestInterval, + syncThreshold: syncThreshold, + selectionMode: NodeSelectionModeRoundRobin, + }, + rpc: rpc, + lggr: lggr, + }) + defer func() { assert.NoError(t, node.close()) }() + node.stateLatestBlockNumber = 20 + node.nLiveNodes = func() (count int, blockNumber int64, totalDifficulty *utils.Big) { + return 10, syncThreshold + node.stateLatestBlockNumber + 1, 
utils.NewBigI(10) + } + rpc.On("ClientVersion", mock.Anything).Return("", nil) + // tries to redial in outOfSync + rpc.On("Dial", mock.Anything).Return(errors.New("failed to dial")).Run(func(_ mock.Arguments) { + assert.Equal(t, nodeStateOutOfSync, node.State()) + }).Once() + // disconnects all on transfer to unreachable or outOfSync + rpc.On("DisconnectAll").Maybe() + // might be called in unreachable loop + rpc.On("Dial", mock.Anything).Return(errors.New("failed to dial")).Maybe() + node.declareAlive() + tests.AssertLogEventually(t, observedLogs, "Failed to dial out-of-sync RPC node") + }) + t.Run("when behind more than SyncThreshold but we are the last live node, forcibly stays alive", func(t *testing.T) { + t.Parallel() + rpc := newMockNodeClient[types.ID, Head](t) + lggr, observedLogs := logger.TestLoggerObserved(t, zap.DebugLevel) + const syncThreshold = 10 + node := newSubscribedNode(t, testNodeOpts{ + config: testNodeConfig{ + pollInterval: tests.TestInterval, + syncThreshold: syncThreshold, + selectionMode: NodeSelectionModeRoundRobin, + }, + rpc: rpc, + lggr: lggr, + }) + defer func() { assert.NoError(t, node.close()) }() + node.stateLatestBlockNumber = 20 + node.nLiveNodes = func() (count int, blockNumber int64, totalDifficulty *utils.Big) { + return 1, syncThreshold + node.stateLatestBlockNumber + 1, utils.NewBigI(10) + } + rpc.On("ClientVersion", mock.Anything).Return("", nil) + node.declareAlive() + tests.AssertLogEventually(t, observedLogs, fmt.Sprintf("RPC endpoint has fallen behind; %s %s", msgCannotDisable, msgDegradedState)) + }) + t.Run("when behind but SyncThreshold=0, stay alive", func(t *testing.T) { + t.Parallel() + rpc := newMockNodeClient[types.ID, Head](t) + lggr, observedLogs := logger.TestLoggerObserved(t, zap.DebugLevel) + node := newSubscribedNode(t, testNodeOpts{ + config: testNodeConfig{ + pollInterval: tests.TestInterval, + syncThreshold: 0, + selectionMode: NodeSelectionModeRoundRobin, + }, + rpc: rpc, + lggr: lggr, + }) + defer func() { assert.NoError(t, node.close()) }() + node.stateLatestBlockNumber = 20 + node.nLiveNodes = func() (count int, blockNumber int64, totalDifficulty *utils.Big) { + return 1, node.stateLatestBlockNumber + 100, utils.NewBigI(10) + } + rpc.On("ClientVersion", mock.Anything).Return("", nil) + node.declareAlive() + tests.AssertLogCountEventually(t, observedLogs, "Version poll successful", 2) + assert.Equal(t, nodeStateAlive, node.State()) + }) + + t.Run("when no new heads received for threshold, transitions to out of sync", func(t *testing.T) { + t.Parallel() + rpc := newMockNodeClient[types.ID, Head](t) + node := newSubscribedNode(t, testNodeOpts{ + config: testNodeConfig{}, + noNewHeadsThreshold: tests.TestInterval, + rpc: rpc, + }) + defer func() { assert.NoError(t, node.close()) }() + // tries to redial in outOfSync + rpc.On("Dial", mock.Anything).Return(errors.New("failed to dial")).Run(func(_ mock.Arguments) { + assert.Equal(t, nodeStateOutOfSync, node.State()) + }).Once() + // disconnects all on transfer to unreachable or outOfSync + rpc.On("DisconnectAll").Maybe() + // might be called in unreachable loop + rpc.On("Dial", mock.Anything).Return(errors.New("failed to dial")).Maybe() + node.declareAlive() + tests.AssertEventually(t, func() bool { + // right after outOfSync we'll transfer to unreachable due to returned error on Dial + // we check that we were in out of sync state on first Dial call + return node.State() == nodeStateUnreachable + }) + }) + t.Run("when no new heads received for threshold but we are the last live 
node, forcibly stays alive", func(t *testing.T) { + t.Parallel() + rpc := newMockNodeClient[types.ID, Head](t) + lggr, observedLogs := logger.TestLoggerObserved(t, zap.DebugLevel) + node := newSubscribedNode(t, testNodeOpts{ + config: testNodeConfig{}, + lggr: lggr, + noNewHeadsThreshold: tests.TestInterval, + rpc: rpc, + }) + defer func() { assert.NoError(t, node.close()) }() + node.nLiveNodes = func() (count int, blockNumber int64, totalDifficulty *utils.Big) { + return 1, 20, utils.NewBigI(10) + } + node.declareAlive() + tests.AssertLogEventually(t, observedLogs, fmt.Sprintf("RPC endpoint detected out of sync; %s %s", msgCannotDisable, msgDegradedState)) + assert.Equal(t, nodeStateAlive, node.State()) + }) + + t.Run("rpc closed head channel", func(t *testing.T) { + t.Parallel() + rpc := newMockNodeClient[types.ID, Head](t) + sub := mocks.NewSubscription(t) + sub.On("Err").Return((<-chan error)(nil)) + sub.On("Unsubscribe").Once() + rpc.On("Subscribe", mock.Anything, mock.Anything, rpcSubscriptionMethodNewHeads).Run(func(args mock.Arguments) { + ch := args.Get(1).(chan<- Head) + close(ch) + }).Return(sub, nil).Once() + rpc.On("SetAliveLoopSub", sub).Once() + lggr, observedLogs := logger.TestLoggerObserved(t, zap.ErrorLevel) + node := newDialedNode(t, testNodeOpts{ + lggr: lggr, + config: testNodeConfig{}, + noNewHeadsThreshold: tests.TestInterval, + rpc: rpc, + }) + defer func() { assert.NoError(t, node.close()) }() + // disconnects all on transfer to unreachable or outOfSync + rpc.On("DisconnectAll").Once() + // might be called in unreachable loop + rpc.On("Dial", mock.Anything).Return(errors.New("failed to dial")).Maybe() + node.declareAlive() + tests.AssertLogEventually(t, observedLogs, "Subscription channel unexpectedly closed") + assert.Equal(t, nodeStateUnreachable, node.State()) + + }) + t.Run("updates block number and difficulty on new head", func(t *testing.T) { + t.Parallel() + rpc := newMockNodeClient[types.ID, Head](t) + sub := mocks.NewSubscription(t) + sub.On("Err").Return((<-chan error)(nil)) + sub.On("Unsubscribe").Once() + expectedBlockNumber := rand.Int64() + expectedDiff := utils.NewBigI(rand.Int64()) + rpc.On("Subscribe", mock.Anything, mock.Anything, rpcSubscriptionMethodNewHeads).Run(func(args mock.Arguments) { + ch := args.Get(1).(chan<- Head) + go writeHeads(t, ch, head{BlockNumber: expectedBlockNumber, BlockDifficulty: expectedDiff}) + }).Return(sub, nil).Once() + rpc.On("SetAliveLoopSub", sub).Once() + node := newDialedNode(t, testNodeOpts{ + config: testNodeConfig{}, + rpc: rpc, + }) + defer func() { assert.NoError(t, node.close()) }() + node.declareAlive() + tests.AssertEventually(t, func() bool { + state, block, diff := node.StateAndLatest() + return state == nodeStateAlive && block == expectedBlockNumber == diff.Equal(expectedDiff) + }) + }) +} + +type head struct { + BlockNumber int64 + BlockDifficulty *utils.Big +} + +func writeHeads(t *testing.T, ch chan<- Head, heads ...head) { + for _, head := range heads { + h := newMockHead(t) + h.On("BlockNumber").Return(head.BlockNumber) + h.On("BlockDifficulty").Return(head.BlockDifficulty) + select { + case ch <- h: + case <-tests.Context(t).Done(): + return + } + } +} + +func setupRPCForAliveLoop(t *testing.T, rpc *mockNodeClient[types.ID, Head]) { + rpc.On("Dial", mock.Anything).Return(nil).Maybe() + aliveSubscription := mocks.NewSubscription(t) + aliveSubscription.On("Err").Return((<-chan error)(nil)).Maybe() + aliveSubscription.On("Unsubscribe").Maybe() + rpc.On("Subscribe", mock.Anything, mock.Anything, 
rpcSubscriptionMethodNewHeads).Return(aliveSubscription, nil).Maybe() + rpc.On("SetAliveLoopSub", mock.Anything).Maybe() +} + +func TestUnit_NodeLifecycle_outOfSyncLoop(t *testing.T) { + t.Parallel() + + newAliveNode := func(t *testing.T, opts testNodeOpts) testNode { + node := newTestNode(t, opts) + opts.rpc.On("Close").Return(nil).Once() + // disconnects all on transfer to unreachable or outOfSync + opts.rpc.On("DisconnectAll") + node.setState(nodeStateAlive) + return node + } + + stubIsOutOfSync := func(num int64, td *utils.Big) bool { + return false + } + + t.Run("returns on closed", func(t *testing.T) { + t.Parallel() + node := newTestNode(t, testNodeOpts{}) + node.setState(nodeStateClosed) + node.wg.Add(1) + node.outOfSyncLoop(stubIsOutOfSync) + }) + t.Run("on old blocks stays outOfSync and returns on close", func(t *testing.T) { + t.Parallel() + rpc := newMockNodeClient[types.ID, Head](t) + nodeChainID := types.RandomID() + lggr, observedLogs := logger.TestLoggerObserved(t, zap.DebugLevel) + node := newAliveNode(t, testNodeOpts{ + rpc: rpc, + chainID: nodeChainID, + lggr: lggr, + }) + defer func() { assert.NoError(t, node.close()) }() + + rpc.On("Dial", mock.Anything).Return(nil).Once() + // might be called multiple times + rpc.On("ChainID", mock.Anything).Return(nodeChainID, nil).Once() + + outOfSyncSubscription := mocks.NewSubscription(t) + outOfSyncSubscription.On("Err").Return((<-chan error)(nil)) + outOfSyncSubscription.On("Unsubscribe").Once() + heads := []head{{BlockNumber: 7}, {BlockNumber: 11}, {BlockNumber: 13}} + rpc.On("Subscribe", mock.Anything, mock.Anything, rpcSubscriptionMethodNewHeads).Run(func(args mock.Arguments) { + ch := args.Get(1).(chan<- Head) + go writeHeads(t, ch, heads...) + }).Return(outOfSyncSubscription, nil).Once() + rpc.On("Dial", mock.Anything).Return(errors.New("failed to redial")).Maybe() + + node.declareOutOfSync(func(num int64, td *utils.Big) bool { + return true + }) + tests.AssertLogCountEventually(t, observedLogs, msgReceivedBlock, len(heads)) + assert.Equal(t, nodeStateOutOfSync, node.State()) + }) + t.Run("if initial dial fails, transitions to unreachable", func(t *testing.T) { + t.Parallel() + rpc := newMockNodeClient[types.ID, Head](t) + node := newAliveNode(t, testNodeOpts{ + rpc: rpc, + }) + defer func() { assert.NoError(t, node.close()) }() + + expectedError := errors.New("failed to dial rpc") + // might be called again in unreachable loop, so no need to set once + rpc.On("Dial", mock.Anything).Return(expectedError) + node.declareOutOfSync(stubIsOutOfSync) + tests.AssertEventually(t, func() bool { + return node.State() == nodeStateUnreachable + }) + }) + t.Run("if fail to get chainID, transitions to invalidChainID", func(t *testing.T) { + t.Parallel() + rpc := newMockNodeClient[types.ID, Head](t) + node := newAliveNode(t, testNodeOpts{ + rpc: rpc, + }) + defer func() { assert.NoError(t, node.close()) }() + + rpc.On("Dial", mock.Anything).Return(nil).Once() + expectedError := errors.New("failed to get chain ID") + // might be called multiple times + rpc.On("ChainID", mock.Anything).Return(types.NewIDFromInt(0), expectedError) + node.declareOutOfSync(stubIsOutOfSync) + tests.AssertEventually(t, func() bool { + return node.State() == nodeStateInvalidChainID + }) + }) + t.Run("if chainID does not match, transitions to invalidChainID", func(t *testing.T) { + t.Parallel() + rpc := newMockNodeClient[types.ID, Head](t) + nodeChainID := types.NewIDFromInt(10) + rpcChainID := types.NewIDFromInt(11) + node := newAliveNode(t, testNodeOpts{ + rpc: 
rpc, + chainID: nodeChainID, + }) + defer func() { assert.NoError(t, node.close()) }() + + rpc.On("Dial", mock.Anything).Return(nil).Once() + // might be called multiple times + rpc.On("ChainID", mock.Anything).Return(rpcChainID, nil) + node.declareOutOfSync(stubIsOutOfSync) + tests.AssertEventually(t, func() bool { + return node.State() == nodeStateInvalidChainID + }) + }) + t.Run("if fails to subscribe, becomes unreachable", func(t *testing.T) { + t.Parallel() + rpc := newMockNodeClient[types.ID, Head](t) + nodeChainID := types.RandomID() + node := newAliveNode(t, testNodeOpts{ + rpc: rpc, + chainID: nodeChainID, + }) + defer func() { assert.NoError(t, node.close()) }() + + rpc.On("Dial", mock.Anything).Return(nil).Once() + // might be called multiple times + rpc.On("ChainID", mock.Anything).Return(nodeChainID, nil).Once() + expectedError := errors.New("failed to subscribe") + rpc.On("Subscribe", mock.Anything, mock.Anything, rpcSubscriptionMethodNewHeads).Return(nil, expectedError) + rpc.On("Dial", mock.Anything).Return(errors.New("failed to redial")).Maybe() + node.declareOutOfSync(stubIsOutOfSync) + tests.AssertEventually(t, func() bool { + return node.State() == nodeStateUnreachable + }) + }) + t.Run("on subscription termination becomes unreachable", func(t *testing.T) { + t.Parallel() + rpc := newMockNodeClient[types.ID, Head](t) + nodeChainID := types.RandomID() + lggr, observedLogs := logger.TestLoggerObserved(t, zap.ErrorLevel) + node := newAliveNode(t, testNodeOpts{ + rpc: rpc, + chainID: nodeChainID, + lggr: lggr, + }) + defer func() { assert.NoError(t, node.close()) }() + + rpc.On("Dial", mock.Anything).Return(nil).Once() + // might be called multiple times + rpc.On("ChainID", mock.Anything).Return(nodeChainID, nil).Once() + + sub := mocks.NewSubscription(t) + errChan := make(chan error, 1) + errChan <- errors.New("subscription was terminated") + sub.On("Err").Return((<-chan error)(errChan)) + sub.On("Unsubscribe").Once() + rpc.On("Subscribe", mock.Anything, mock.Anything, rpcSubscriptionMethodNewHeads).Return(sub, nil).Once() + rpc.On("Dial", mock.Anything).Return(errors.New("failed to redial")).Maybe() + node.declareOutOfSync(stubIsOutOfSync) + tests.AssertLogEventually(t, observedLogs, "Subscription was terminated") + tests.AssertEventually(t, func() bool { + return node.State() == nodeStateUnreachable + }) + }) + t.Run("becomes unreachable if head channel is closed", func(t *testing.T) { + t.Parallel() + rpc := newMockNodeClient[types.ID, Head](t) + nodeChainID := types.RandomID() + lggr, observedLogs := logger.TestLoggerObserved(t, zap.ErrorLevel) + node := newAliveNode(t, testNodeOpts{ + rpc: rpc, + chainID: nodeChainID, + lggr: lggr, + }) + defer func() { assert.NoError(t, node.close()) }() + + rpc.On("Dial", mock.Anything).Return(nil).Once() + // might be called multiple times + rpc.On("ChainID", mock.Anything).Return(nodeChainID, nil).Once() + + sub := mocks.NewSubscription(t) + sub.On("Err").Return((<-chan error)(nil)) + sub.On("Unsubscribe").Once() + rpc.On("Subscribe", mock.Anything, mock.Anything, rpcSubscriptionMethodNewHeads).Run(func(args mock.Arguments) { + ch := args.Get(1).(chan<- Head) + close(ch) + }).Return(sub, nil).Once() + rpc.On("Dial", mock.Anything).Return(errors.New("failed to redial")).Maybe() + node.declareOutOfSync(stubIsOutOfSync) + tests.AssertLogEventually(t, observedLogs, "Subscription channel unexpectedly closed") + tests.AssertEventually(t, func() bool { + return node.State() == nodeStateUnreachable + }) + }) + + t.Run("becomes alive if it 
receives a newer head", func(t *testing.T) { + t.Parallel() + rpc := newMockNodeClient[types.ID, Head](t) + nodeChainID := types.RandomID() + lggr, observedLogs := logger.TestLoggerObserved(t, zap.DebugLevel) + node := newAliveNode(t, testNodeOpts{ + rpc: rpc, + chainID: nodeChainID, + lggr: lggr, + }) + defer func() { assert.NoError(t, node.close()) }() + + rpc.On("Dial", mock.Anything).Return(nil).Once() + // might be called multiple times + rpc.On("ChainID", mock.Anything).Return(nodeChainID, nil).Once() + + outOfSyncSubscription := mocks.NewSubscription(t) + outOfSyncSubscription.On("Err").Return((<-chan error)(nil)) + outOfSyncSubscription.On("Unsubscribe").Once() + const highestBlock = 1000 + rpc.On("Subscribe", mock.Anything, mock.Anything, rpcSubscriptionMethodNewHeads).Run(func(args mock.Arguments) { + ch := args.Get(1).(chan<- Head) + go writeHeads(t, ch, head{BlockNumber: highestBlock - 1}, head{BlockNumber: highestBlock}) + }).Return(outOfSyncSubscription, nil).Once() + + setupRPCForAliveLoop(t, rpc) + + node.declareOutOfSync(func(num int64, td *utils.Big) bool { + return num < highestBlock + }) + tests.AssertLogEventually(t, observedLogs, msgReceivedBlock) + tests.AssertLogEventually(t, observedLogs, msgInSync) + tests.AssertEventually(t, func() bool { + return node.State() == nodeStateAlive + }) + }) + t.Run("becomes alive if there is no other nodes", func(t *testing.T) { + t.Parallel() + rpc := newMockNodeClient[types.ID, Head](t) + nodeChainID := types.RandomID() + lggr, observedLogs := logger.TestLoggerObserved(t, zap.DebugLevel) + node := newAliveNode(t, testNodeOpts{ + noNewHeadsThreshold: tests.TestInterval, + rpc: rpc, + chainID: nodeChainID, + lggr: lggr, + }) + defer func() { assert.NoError(t, node.close()) }() + node.nLiveNodes = func() (count int, blockNumber int64, totalDifficulty *utils.Big) { + return 0, 100, utils.NewBigI(200) + } + + rpc.On("Dial", mock.Anything).Return(nil).Once() + // might be called multiple times + rpc.On("ChainID", mock.Anything).Return(nodeChainID, nil).Once() + + outOfSyncSubscription := mocks.NewSubscription(t) + outOfSyncSubscription.On("Err").Return((<-chan error)(nil)) + outOfSyncSubscription.On("Unsubscribe").Once() + rpc.On("Subscribe", mock.Anything, mock.Anything, rpcSubscriptionMethodNewHeads).Return(outOfSyncSubscription, nil).Once() + + setupRPCForAliveLoop(t, rpc) + + node.declareOutOfSync(stubIsOutOfSync) + tests.AssertLogEventually(t, observedLogs, "RPC endpoint is still out of sync, but there are no other available nodes. 
This RPC node will be forcibly moved back into the live pool in a degraded state") + tests.AssertEventually(t, func() bool { + return node.State() == nodeStateAlive + }) + }) +} + +func TestUnit_NodeLifecycle_unreachableLoop(t *testing.T) { + t.Parallel() + + newAliveNode := func(t *testing.T, opts testNodeOpts) testNode { + node := newTestNode(t, opts) + opts.rpc.On("Close").Return(nil).Once() + // disconnects all on transfer to unreachable + opts.rpc.On("DisconnectAll") + + node.setState(nodeStateAlive) + return node + } + t.Run("returns on closed", func(t *testing.T) { + t.Parallel() + node := newTestNode(t, testNodeOpts{}) + node.setState(nodeStateClosed) + node.wg.Add(1) + node.unreachableLoop() + + }) + t.Run("on failed redial, keeps trying", func(t *testing.T) { + t.Parallel() + rpc := newMockNodeClient[types.ID, Head](t) + nodeChainID := types.RandomID() + lggr, observedLogs := logger.TestLoggerObserved(t, zap.DebugLevel) + node := newAliveNode(t, testNodeOpts{ + rpc: rpc, + chainID: nodeChainID, + lggr: lggr, + }) + defer func() { assert.NoError(t, node.close()) }() + + rpc.On("Dial", mock.Anything).Return(errors.New("failed to dial")) + node.declareUnreachable() + tests.AssertLogCountEventually(t, observedLogs, "Failed to redial RPC node; still unreachable", 2) + }) + t.Run("on failed chainID verification, keeps trying", func(t *testing.T) { + t.Parallel() + rpc := newMockNodeClient[types.ID, Head](t) + nodeChainID := types.RandomID() + lggr, observedLogs := logger.TestLoggerObserved(t, zap.DebugLevel) + node := newAliveNode(t, testNodeOpts{ + rpc: rpc, + chainID: nodeChainID, + lggr: lggr, + }) + defer func() { assert.NoError(t, node.close()) }() + + rpc.On("Dial", mock.Anything).Return(nil) + rpc.On("ChainID", mock.Anything).Run(func(_ mock.Arguments) { + assert.Equal(t, nodeStateDialed, node.State()) + }).Return(nodeChainID, errors.New("failed to get chain id")) + node.declareUnreachable() + tests.AssertLogCountEventually(t, observedLogs, "Failed to redial RPC node; verify failed", 2) + }) + t.Run("on chain ID mismatch transitions to invalidChainID", func(t *testing.T) { + t.Parallel() + rpc := newMockNodeClient[types.ID, Head](t) + nodeChainID := types.NewIDFromInt(10) + rpcChainID := types.NewIDFromInt(11) + node := newAliveNode(t, testNodeOpts{ + rpc: rpc, + chainID: nodeChainID, + }) + defer func() { assert.NoError(t, node.close()) }() + + rpc.On("Dial", mock.Anything).Return(nil) + rpc.On("ChainID", mock.Anything).Return(rpcChainID, nil) + node.declareUnreachable() + tests.AssertEventually(t, func() bool { + return node.State() == nodeStateInvalidChainID + }) + }) + t.Run("on valid chain ID becomes alive", func(t *testing.T) { + t.Parallel() + rpc := newMockNodeClient[types.ID, Head](t) + nodeChainID := types.RandomID() + node := newAliveNode(t, testNodeOpts{ + rpc: rpc, + chainID: nodeChainID, + }) + defer func() { assert.NoError(t, node.close()) }() + + rpc.On("Dial", mock.Anything).Return(nil) + rpc.On("ChainID", mock.Anything).Return(nodeChainID, nil) + + setupRPCForAliveLoop(t, rpc) + + node.declareUnreachable() + tests.AssertEventually(t, func() bool { + return node.State() == nodeStateAlive + }) + }) +} + +func TestUnit_NodeLifecycle_invalidChainIDLoop(t *testing.T) { + t.Parallel() + newDialedNode := func(t *testing.T, opts testNodeOpts) testNode { + node := newTestNode(t, opts) + opts.rpc.On("Close").Return(nil).Once() + opts.rpc.On("DisconnectAll") + + node.setState(nodeStateDialed) + return node + } + t.Run("returns on closed", func(t *testing.T) { + 
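+ // wg.Add(1) below mirrors the bookkeeping Start normally performs before + // launching a lifecycle loop; the loop is expected to see the Closed state + // and return immediately, releasing the WaitGroup, instead of blocking.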
t.Parallel() + node := newTestNode(t, testNodeOpts{}) + node.setState(nodeStateClosed) + node.wg.Add(1) + node.invalidChainIDLoop() + + }) + t.Run("on failed chainID call becomes unreachable", func(t *testing.T) { + t.Parallel() + rpc := newMockNodeClient[types.ID, Head](t) + nodeChainID := types.RandomID() + lggr, observedLogs := logger.TestLoggerObserved(t, zap.DebugLevel) + node := newDialedNode(t, testNodeOpts{ + rpc: rpc, + chainID: nodeChainID, + lggr: lggr, + }) + defer func() { assert.NoError(t, node.close()) }() + + rpc.On("ChainID", mock.Anything).Return(nodeChainID, errors.New("failed to get chain id")) + // for unreachable loop + rpc.On("Dial", mock.Anything).Return(errors.New("failed to dial")).Maybe() + node.declareInvalidChainID() + tests.AssertLogEventually(t, observedLogs, "Unexpected error while verifying RPC node chain ID") + tests.AssertEventually(t, func() bool { + return node.State() == nodeStateUnreachable + }) + }) + t.Run("on chainID mismatch keeps trying", func(t *testing.T) { + t.Parallel() + rpc := newMockNodeClient[types.ID, Head](t) + nodeChainID := types.NewIDFromInt(10) + rpcChainID := types.NewIDFromInt(11) + lggr, observedLogs := logger.TestLoggerObserved(t, zap.DebugLevel) + node := newDialedNode(t, testNodeOpts{ + rpc: rpc, + chainID: nodeChainID, + lggr: lggr, + }) + defer func() { assert.NoError(t, node.close()) }() + + rpc.On("ChainID", mock.Anything).Return(rpcChainID, nil) + node.declareInvalidChainID() + tests.AssertLogCountEventually(t, observedLogs, "Failed to verify RPC node; remote endpoint returned the wrong chain ID", 2) + tests.AssertEventually(t, func() bool { + return node.State() == nodeStateInvalidChainID + }) + }) + t.Run("on valid chainID becomes alive", func(t *testing.T) { + t.Parallel() + rpc := newMockNodeClient[types.ID, Head](t) + nodeChainID := types.RandomID() + node := newDialedNode(t, testNodeOpts{ + rpc: rpc, + chainID: nodeChainID, + }) + defer func() { assert.NoError(t, node.close()) }() + + rpc.On("ChainID", mock.Anything).Return(nodeChainID, nil) + + setupRPCForAliveLoop(t, rpc) + + node.declareInvalidChainID() + tests.AssertEventually(t, func() bool { + return node.State() == nodeStateAlive + }) + }) +} + +func TestUnit_NodeLifecycle_start(t *testing.T) { + t.Parallel() + + newNode := func(t *testing.T, opts testNodeOpts) testNode { + node := newTestNode(t, opts) + opts.rpc.On("Close").Return(nil).Once() + + return node + } + t.Run("if fails on initial dial, becomes unreachable", func(t *testing.T) { + t.Parallel() + rpc := newMockNodeClient[types.ID, Head](t) + nodeChainID := types.RandomID() + lggr, observedLogs := logger.TestLoggerObserved(t, zap.DebugLevel) + node := newNode(t, testNodeOpts{ + rpc: rpc, + chainID: nodeChainID, + lggr: lggr, + }) + defer func() { assert.NoError(t, node.close()) }() + + rpc.On("Dial", mock.Anything).Return(errors.New("failed to dial")) + // disconnects all on transfer to unreachable + rpc.On("DisconnectAll") + err := node.Start(tests.Context(t)) + assert.NoError(t, err) + tests.AssertLogEventually(t, observedLogs, "Dial failed: Node is unreachable") + tests.AssertEventually(t, func() bool { + return node.State() == nodeStateUnreachable + }) + }) + t.Run("if chainID verification fails, becomes unreachable", func(t *testing.T) { + t.Parallel() + rpc := newMockNodeClient[types.ID, Head](t) + nodeChainID := types.RandomID() + lggr, observedLogs := logger.TestLoggerObserved(t, zap.DebugLevel) + node := newNode(t, testNodeOpts{ + rpc: rpc, + chainID: nodeChainID, + lggr: lggr, + }) + defer 
func() { assert.NoError(t, node.close()) }() + + rpc.On("Dial", mock.Anything).Return(nil) + rpc.On("ChainID", mock.Anything).Run(func(_ mock.Arguments) { + assert.Equal(t, nodeStateDialed, node.State()) + }).Return(nodeChainID, errors.New("failed to get chain id")) + // disconnects all on transfer to unreachable + rpc.On("DisconnectAll") + err := node.Start(tests.Context(t)) + assert.NoError(t, err) + tests.AssertLogEventually(t, observedLogs, "Verify failed") + tests.AssertEventually(t, func() bool { + return node.State() == nodeStateUnreachable + }) + }) + t.Run("on chain ID mismatch transitions to invalidChainID", func(t *testing.T) { + t.Parallel() + rpc := newMockNodeClient[types.ID, Head](t) + nodeChainID := types.NewIDFromInt(10) + rpcChainID := types.NewIDFromInt(11) + node := newNode(t, testNodeOpts{ + rpc: rpc, + chainID: nodeChainID, + }) + defer func() { assert.NoError(t, node.close()) }() + + rpc.On("Dial", mock.Anything).Return(nil) + rpc.On("ChainID", mock.Anything).Return(rpcChainID, nil) + // disconnects all on transfer to unreachable + rpc.On("DisconnectAll") + err := node.Start(tests.Context(t)) + assert.NoError(t, err) + tests.AssertEventually(t, func() bool { + return node.State() == nodeStateInvalidChainID + }) + }) + t.Run("on valid chain ID becomes alive", func(t *testing.T) { + t.Parallel() + rpc := newMockNodeClient[types.ID, Head](t) + nodeChainID := types.RandomID() + node := newNode(t, testNodeOpts{ + rpc: rpc, + chainID: nodeChainID, + }) + defer func() { assert.NoError(t, node.close()) }() + + rpc.On("Dial", mock.Anything).Return(nil) + rpc.On("ChainID", mock.Anything).Return(nodeChainID, nil) + + setupRPCForAliveLoop(t, rpc) + + err := node.Start(tests.Context(t)) + assert.NoError(t, err) + tests.AssertEventually(t, func() bool { + return node.State() == nodeStateAlive + }) + }) +} + +func TestUnit_NodeLifecycle_syncStatus(t *testing.T) { + t.Parallel() + t.Run("skip if nLiveNodes is not configured", func(t *testing.T) { + node := newTestNode(t, testNodeOpts{}) + outOfSync, liveNodes := node.syncStatus(0, nil) + assert.Equal(t, false, outOfSync) + assert.Equal(t, 0, liveNodes) + }) + t.Run("skip if syncThreshold is not configured", func(t *testing.T) { + node := newTestNode(t, testNodeOpts{}) + node.nLiveNodes = func() (count int, blockNumber int64, totalDifficulty *utils.Big) { + return + } + outOfSync, liveNodes := node.syncStatus(0, nil) + assert.Equal(t, false, outOfSync) + assert.Equal(t, 0, liveNodes) + }) + t.Run("panics on invalid selection mode", func(t *testing.T) { + node := newTestNode(t, testNodeOpts{ + config: testNodeConfig{syncThreshold: 1}, + }) + node.nLiveNodes = func() (count int, blockNumber int64, totalDifficulty *utils.Big) { + return + } + assert.Panics(t, func() { + _, _ = node.syncStatus(0, nil) + }) + }) + t.Run("block height selection mode", func(t *testing.T) { + const syncThreshold = 10 + const highestBlock = 1000 + const nodesNum = 20 + const totalDifficulty = 3000 + testCases := []struct { + name string + blockNumber int64 + outOfSync bool + }{ + { + name: "below threshold", + blockNumber: highestBlock - syncThreshold - 1, + outOfSync: true, + }, + { + name: "equal to threshold", + blockNumber: highestBlock - syncThreshold, + outOfSync: false, + }, + { + name: "equal to highest block", + blockNumber: highestBlock, + outOfSync: false, + }, + { + name: "higher than highest block", + blockNumber: highestBlock + 1, + outOfSync: false, + }, + } + + for _, selectionMode := range []string{NodeSelectionModeHighestHead, 
NodeSelectionModeRoundRobin, NodeSelectionModePriorityLevel} { + node := newTestNode(t, testNodeOpts{ + config: testNodeConfig{ + syncThreshold: syncThreshold, + selectionMode: selectionMode, + }, + }) + node.nLiveNodes = func() (int, int64, *utils.Big) { + return nodesNum, highestBlock, utils.NewBigI(totalDifficulty) + } + for _, td := range []int64{totalDifficulty - syncThreshold - 1, totalDifficulty - syncThreshold, totalDifficulty, totalDifficulty + 1} { + for _, testCase := range testCases { + t.Run(fmt.Sprintf("%s: selectionMode: %s: total difficulty: %d", testCase.name, selectionMode, td), func(t *testing.T) { + outOfSync, liveNodes := node.syncStatus(testCase.blockNumber, utils.NewBigI(td)) + assert.Equal(t, nodesNum, liveNodes) + assert.Equal(t, testCase.outOfSync, outOfSync) + }) + } + } + } + + }) + t.Run("total difficulty selection mode", func(t *testing.T) { + const syncThreshold = 10 + const highestBlock = 1000 + const nodesNum = 20 + const totalDifficulty = 3000 + testCases := []struct { + name string + totalDifficulty int64 + outOfSync bool + }{ + { + name: "below threshold", + totalDifficulty: totalDifficulty - syncThreshold - 1, + outOfSync: true, + }, + { + name: "equal to threshold", + totalDifficulty: totalDifficulty - syncThreshold, + outOfSync: false, + }, + { + name: "equal to highest total difficulty", + totalDifficulty: totalDifficulty, + outOfSync: false, + }, + { + name: "higher than highest total difficulty", + totalDifficulty: totalDifficulty + 1, + outOfSync: false, + }, + } + + node := newTestNode(t, testNodeOpts{ + config: testNodeConfig{ + syncThreshold: syncThreshold, + selectionMode: NodeSelectionModeTotalDifficulty, + }, + }) + node.nLiveNodes = func() (int, int64, *utils.Big) { + return nodesNum, highestBlock, utils.NewBigI(totalDifficulty) + } + for _, hb := range []int64{highestBlock - syncThreshold - 1, highestBlock - syncThreshold, highestBlock, highestBlock + 1} { + for _, testCase := range testCases { + t.Run(fmt.Sprintf("%s: selectionMode: %s: highest block: %d", testCase.name, NodeSelectionModeTotalDifficulty, hb), func(t *testing.T) { + outOfSync, liveNodes := node.syncStatus(hb, utils.NewBigI(testCase.totalDifficulty)) + assert.Equal(t, nodesNum, liveNodes) + assert.Equal(t, testCase.outOfSync, outOfSync) + }) + } + } + + }) +} diff --git a/common/client/node_selector.go b/common/client/node_selector.go new file mode 100644 index 00000000000..45604ebe8d9 --- /dev/null +++ b/common/client/node_selector.go @@ -0,0 +1,46 @@ +package client + +import ( + "fmt" + + "github.com/smartcontractkit/chainlink/v2/common/types" +) + +const ( + NodeSelectionModeHighestHead = "HighestHead" + NodeSelectionModeRoundRobin = "RoundRobin" + NodeSelectionModeTotalDifficulty = "TotalDifficulty" + NodeSelectionModePriorityLevel = "PriorityLevel" +) + +//go:generate mockery --quiet --name NodeSelector --structname mockNodeSelector --filename "mock_node_selector_test.go" --inpackage --case=underscore +type NodeSelector[ + CHAIN_ID types.ID, + HEAD Head, + RPC NodeClient[CHAIN_ID, HEAD], +] interface { + // Select returns a Node, or nil if none can be selected. + // Implementation must be thread-safe. + Select() Node[CHAIN_ID, HEAD, RPC] + // Name returns the strategy name, e.g. 
"HighestHead" or "RoundRobin" + Name() string +} + +func newNodeSelector[ + CHAIN_ID types.ID, + HEAD Head, + RPC NodeClient[CHAIN_ID, HEAD], +](selectionMode string, nodes []Node[CHAIN_ID, HEAD, RPC]) NodeSelector[CHAIN_ID, HEAD, RPC] { + switch selectionMode { + case NodeSelectionModeHighestHead: + return NewHighestHeadNodeSelector[CHAIN_ID, HEAD, RPC](nodes) + case NodeSelectionModeRoundRobin: + return NewRoundRobinSelector[CHAIN_ID, HEAD, RPC](nodes) + case NodeSelectionModeTotalDifficulty: + return NewTotalDifficultyNodeSelector[CHAIN_ID, HEAD, RPC](nodes) + case NodeSelectionModePriorityLevel: + return NewPriorityLevelNodeSelector[CHAIN_ID, HEAD, RPC](nodes) + default: + panic(fmt.Sprintf("unsupported NodeSelectionMode: %s", selectionMode)) + } +} diff --git a/common/client/node_selector_highest_head_test.go b/common/client/node_selector_highest_head_test.go new file mode 100644 index 00000000000..6e47bbedcae --- /dev/null +++ b/common/client/node_selector_highest_head_test.go @@ -0,0 +1,176 @@ +package client + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/smartcontractkit/chainlink/v2/common/types" +) + +func TestHighestHeadNodeSelectorName(t *testing.T) { + selector := newNodeSelector[types.ID, Head, NodeClient[types.ID, Head]](NodeSelectionModeHighestHead, nil) + assert.Equal(t, selector.Name(), NodeSelectionModeHighestHead) +} + +func TestHighestHeadNodeSelector(t *testing.T) { + t.Parallel() + + type nodeClient NodeClient[types.ID, Head] + + var nodes []Node[types.ID, Head, nodeClient] + + for i := 0; i < 3; i++ { + node := newMockNode[types.ID, Head, nodeClient](t) + if i == 0 { + // first node is out of sync + node.On("StateAndLatest").Return(nodeStateOutOfSync, int64(-1), nil) + } else if i == 1 { + // second node is alive, LatestReceivedBlockNumber = 1 + node.On("StateAndLatest").Return(nodeStateAlive, int64(1), nil) + } else { + // third node is alive, LatestReceivedBlockNumber = 2 (best node) + node.On("StateAndLatest").Return(nodeStateAlive, int64(2), nil) + } + node.On("Order").Maybe().Return(int32(1)) + nodes = append(nodes, node) + } + + selector := newNodeSelector[types.ID, Head, nodeClient](NodeSelectionModeHighestHead, nodes) + assert.Same(t, nodes[2], selector.Select()) + + t.Run("stick to the same node", func(t *testing.T) { + node := newMockNode[types.ID, Head, nodeClient](t) + // fourth node is alive, LatestReceivedBlockNumber = 2 (same as 3rd) + node.On("StateAndLatest").Return(nodeStateAlive, int64(2), nil) + node.On("Order").Return(int32(1)) + nodes = append(nodes, node) + + selector := newNodeSelector(NodeSelectionModeHighestHead, nodes) + assert.Same(t, nodes[2], selector.Select()) + }) + + t.Run("another best node", func(t *testing.T) { + node := newMockNode[types.ID, Head, nodeClient](t) + // fifth node is alive, LatestReceivedBlockNumber = 3 (better than 3rd and 4th) + node.On("StateAndLatest").Return(nodeStateAlive, int64(3), nil) + node.On("Order").Return(int32(1)) + nodes = append(nodes, node) + + selector := newNodeSelector(NodeSelectionModeHighestHead, nodes) + assert.Same(t, nodes[4], selector.Select()) + }) + + t.Run("nodes never update latest block number", func(t *testing.T) { + node1 := newMockNode[types.ID, Head, nodeClient](t) + node1.On("StateAndLatest").Return(nodeStateAlive, int64(-1), nil) + node1.On("Order").Return(int32(1)) + node2 := newMockNode[types.ID, Head, nodeClient](t) + node2.On("StateAndLatest").Return(nodeStateAlive, int64(-1), nil) + node2.On("Order").Return(int32(1)) + selector := 
newNodeSelector(NodeSelectionModeHighestHead, []Node[types.ID, Head, nodeClient]{node1, node2}) + assert.Same(t, node1, selector.Select()) + }) +} + +func TestHighestHeadNodeSelector_None(t *testing.T) { + t.Parallel() + + type nodeClient NodeClient[types.ID, Head] + var nodes []Node[types.ID, Head, nodeClient] + + for i := 0; i < 3; i++ { + node := newMockNode[types.ID, Head, nodeClient](t) + if i == 0 { + // first node is out of sync + node.On("StateAndLatest").Return(nodeStateOutOfSync, int64(-1), nil) + } else { + // others are unreachable + node.On("StateAndLatest").Return(nodeStateUnreachable, int64(1), nil) + } + nodes = append(nodes, node) + } + + selector := newNodeSelector(NodeSelectionModeHighestHead, nodes) + assert.Nil(t, selector.Select()) +} + +func TestHighestHeadNodeSelectorWithOrder(t *testing.T) { + t.Parallel() + + type nodeClient NodeClient[types.ID, Head] + var nodes []Node[types.ID, Head, nodeClient] + + t.Run("same head and order", func(t *testing.T) { + for i := 0; i < 3; i++ { + node := newMockNode[types.ID, Head, nodeClient](t) + node.On("StateAndLatest").Return(nodeStateAlive, int64(1), nil) + node.On("Order").Return(int32(2)) + nodes = append(nodes, node) + } + selector := newNodeSelector(NodeSelectionModeHighestHead, nodes) + //Should select the first node because all things are equal + assert.Same(t, nodes[0], selector.Select()) + }) + + t.Run("same head but different order", func(t *testing.T) { + node1 := newMockNode[types.ID, Head, nodeClient](t) + node1.On("StateAndLatest").Return(nodeStateAlive, int64(3), nil) + node1.On("Order").Return(int32(3)) + + node2 := newMockNode[types.ID, Head, nodeClient](t) + node2.On("StateAndLatest").Return(nodeStateAlive, int64(3), nil) + node2.On("Order").Return(int32(1)) + + node3 := newMockNode[types.ID, Head, nodeClient](t) + node3.On("StateAndLatest").Return(nodeStateAlive, int64(3), nil) + node3.On("Order").Return(int32(2)) + + nodes := []Node[types.ID, Head, nodeClient]{node1, node2, node3} + selector := newNodeSelector(NodeSelectionModeHighestHead, nodes) + //Should select the second node as it has the highest priority + assert.Same(t, nodes[1], selector.Select()) + }) + + t.Run("different head but same order", func(t *testing.T) { + node1 := newMockNode[types.ID, Head, nodeClient](t) + node1.On("StateAndLatest").Return(nodeStateAlive, int64(1), nil) + node1.On("Order").Maybe().Return(int32(3)) + + node2 := newMockNode[types.ID, Head, nodeClient](t) + node2.On("StateAndLatest").Return(nodeStateAlive, int64(2), nil) + node2.On("Order").Maybe().Return(int32(3)) + + node3 := newMockNode[types.ID, Head, nodeClient](t) + node3.On("StateAndLatest").Return(nodeStateAlive, int64(3), nil) + node3.On("Order").Return(int32(3)) + + nodes := []Node[types.ID, Head, nodeClient]{node1, node2, node3} + selector := newNodeSelector(NodeSelectionModeHighestHead, nodes) + //Should select the third node as it has the highest head + assert.Same(t, nodes[2], selector.Select()) + }) + + t.Run("different head and different order", func(t *testing.T) { + node1 := newMockNode[types.ID, Head, nodeClient](t) + node1.On("StateAndLatest").Return(nodeStateAlive, int64(10), nil) + node1.On("Order").Maybe().Return(int32(3)) + + node2 := newMockNode[types.ID, Head, nodeClient](t) + node2.On("StateAndLatest").Return(nodeStateAlive, int64(11), nil) + node2.On("Order").Maybe().Return(int32(4)) + + node3 := newMockNode[types.ID, Head, nodeClient](t) + node3.On("StateAndLatest").Return(nodeStateAlive, int64(11), nil) + 
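+ // node2 and node3 share the highest head (11); node3 should win the + // tie-break because its Order of 3 (set below) outranks node2's 4, lower + // Order meaning higher priority.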
node3.On("Order").Maybe().Return(int32(3)) + + node4 := newMockNode[types.ID, Head, nodeClient](t) + node4.On("StateAndLatest").Return(nodeStateAlive, int64(10), nil) + node4.On("Order").Maybe().Return(int32(1)) + + nodes := []Node[types.ID, Head, nodeClient]{node1, node2, node3, node4} + selector := newNodeSelector(NodeSelectionModeHighestHead, nodes) + //Should select the third node as it has the highest head and will win the priority tie-breaker + assert.Same(t, nodes[2], selector.Select()) + }) +} diff --git a/common/client/node_selector_priority_level_test.go b/common/client/node_selector_priority_level_test.go new file mode 100644 index 00000000000..ac84645e91c --- /dev/null +++ b/common/client/node_selector_priority_level_test.go @@ -0,0 +1,88 @@ +package client + +import ( + "testing" + + "github.com/smartcontractkit/chainlink/v2/common/types" + + "github.com/stretchr/testify/assert" +) + +func TestPriorityLevelNodeSelectorName(t *testing.T) { + selector := newNodeSelector[types.ID, Head, NodeClient[types.ID, Head]](NodeSelectionModePriorityLevel, nil) + assert.Equal(t, selector.Name(), NodeSelectionModePriorityLevel) +} + +func TestPriorityLevelNodeSelector(t *testing.T) { + t.Parallel() + + type nodeClient NodeClient[types.ID, Head] + var nodes []Node[types.ID, Head, nodeClient] + n1 := newMockNode[types.ID, Head, nodeClient](t) + n1.On("State").Return(nodeStateAlive) + n1.On("Order").Return(int32(1)) + + n2 := newMockNode[types.ID, Head, nodeClient](t) + n2.On("State").Return(nodeStateAlive) + n2.On("Order").Return(int32(1)) + + n3 := newMockNode[types.ID, Head, nodeClient](t) + n3.On("State").Return(nodeStateAlive) + n3.On("Order").Return(int32(1)) + + nodes = append(nodes, n1, n2, n3) + selector := newNodeSelector(NodeSelectionModePriorityLevel, nodes) + assert.Same(t, nodes[0], selector.Select()) + assert.Same(t, nodes[1], selector.Select()) + assert.Same(t, nodes[2], selector.Select()) + assert.Same(t, nodes[0], selector.Select()) + assert.Same(t, nodes[1], selector.Select()) + assert.Same(t, nodes[2], selector.Select()) +} + +func TestPriorityLevelNodeSelector_None(t *testing.T) { + t.Parallel() + + type nodeClient NodeClient[types.ID, Head] + var nodes []Node[types.ID, Head, nodeClient] + + for i := 0; i < 3; i++ { + node := newMockNode[types.ID, Head, nodeClient](t) + if i == 0 { + // first node is out of sync + node.On("State").Return(nodeStateOutOfSync) + node.On("Order").Return(int32(1)) + } else { + // others are unreachable + node.On("State").Return(nodeStateUnreachable) + node.On("Order").Return(int32(1)) + } + nodes = append(nodes, node) + } + + selector := newNodeSelector(NodeSelectionModePriorityLevel, nodes) + assert.Nil(t, selector.Select()) +} + +func TestPriorityLevelNodeSelector_DifferentOrder(t *testing.T) { + t.Parallel() + + type nodeClient NodeClient[types.ID, Head] + var nodes []Node[types.ID, Head, nodeClient] + n1 := newMockNode[types.ID, Head, nodeClient](t) + n1.On("State").Return(nodeStateAlive) + n1.On("Order").Return(int32(1)) + + n2 := newMockNode[types.ID, Head, nodeClient](t) + n2.On("State").Return(nodeStateAlive) + n2.On("Order").Return(int32(2)) + + n3 := newMockNode[types.ID, Head, nodeClient](t) + n3.On("State").Return(nodeStateAlive) + n3.On("Order").Return(int32(3)) + + nodes = append(nodes, n1, n2, n3) + selector := newNodeSelector(NodeSelectionModePriorityLevel, nodes) + assert.Same(t, nodes[0], selector.Select()) + assert.Same(t, nodes[0], selector.Select()) +} diff --git a/common/client/node_selector_round_robin_test.go 
b/common/client/node_selector_round_robin_test.go new file mode 100644 index 00000000000..e5078d858f1 --- /dev/null +++ b/common/client/node_selector_round_robin_test.go @@ -0,0 +1,61 @@ +package client + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/smartcontractkit/chainlink/v2/common/types" +) + +func TestRoundRobinNodeSelectorName(t *testing.T) { + selector := newNodeSelector[types.ID, Head, NodeClient[types.ID, Head]](NodeSelectionModeRoundRobin, nil) + assert.Equal(t, selector.Name(), NodeSelectionModeRoundRobin) +} + +func TestRoundRobinNodeSelector(t *testing.T) { + t.Parallel() + + type nodeClient NodeClient[types.ID, Head] + var nodes []Node[types.ID, Head, nodeClient] + + for i := 0; i < 3; i++ { + node := newMockNode[types.ID, Head, nodeClient](t) + if i == 0 { + // first node is out of sync + node.On("State").Return(nodeStateOutOfSync) + } else { + // second & third nodes are alive + node.On("State").Return(nodeStateAlive) + } + nodes = append(nodes, node) + } + + selector := newNodeSelector(NodeSelectionModeRoundRobin, nodes) + assert.Same(t, nodes[1], selector.Select()) + assert.Same(t, nodes[2], selector.Select()) + assert.Same(t, nodes[1], selector.Select()) + assert.Same(t, nodes[2], selector.Select()) +} + +func TestRoundRobinNodeSelector_None(t *testing.T) { + t.Parallel() + + type nodeClient NodeClient[types.ID, Head] + var nodes []Node[types.ID, Head, nodeClient] + + for i := 0; i < 3; i++ { + node := newMockNode[types.ID, Head, nodeClient](t) + if i == 0 { + // first node is out of sync + node.On("State").Return(nodeStateOutOfSync) + } else { + // others are unreachable + node.On("State").Return(nodeStateUnreachable) + } + nodes = append(nodes, node) + } + + selector := newNodeSelector(NodeSelectionModeRoundRobin, nodes) + assert.Nil(t, selector.Select()) +} diff --git a/common/client/node_selector_test.go b/common/client/node_selector_test.go new file mode 100644 index 00000000000..226cb67168d --- /dev/null +++ b/common/client/node_selector_test.go @@ -0,0 +1,18 @@ +package client + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/smartcontractkit/chainlink/v2/common/types" +) + +func TestNodeSelector(t *testing.T) { + // rest of the tests are located in specific node selectors tests + t.Run("panics on unknown type", func(t *testing.T) { + assert.Panics(t, func() { + _ = newNodeSelector[types.ID, Head, NodeClient[types.ID, Head]]("unknown", nil) + }) + }) +} diff --git a/common/client/node_selector_total_difficulty_test.go b/common/client/node_selector_total_difficulty_test.go new file mode 100644 index 00000000000..4eecb859db9 --- /dev/null +++ b/common/client/node_selector_total_difficulty_test.go @@ -0,0 +1,178 @@ +package client + +import ( + "testing" + + "github.com/smartcontractkit/chainlink/v2/common/types" + "github.com/smartcontractkit/chainlink/v2/core/utils" + + "github.com/stretchr/testify/assert" +) + +func TestTotalDifficultyNodeSelectorName(t *testing.T) { + selector := newNodeSelector[types.ID, Head, NodeClient[types.ID, Head]](NodeSelectionModeTotalDifficulty, nil) + assert.Equal(t, selector.Name(), NodeSelectionModeTotalDifficulty) +} + +func TestTotalDifficultyNodeSelector(t *testing.T) { + t.Parallel() + + type nodeClient NodeClient[types.ID, Head] + var nodes []Node[types.ID, Head, nodeClient] + + for i := 0; i < 3; i++ { + node := newMockNode[types.ID, Head, nodeClient](t) + if i == 0 { + // first node is out of sync + node.On("StateAndLatest").Return(nodeStateOutOfSync, int64(-1), 
nil) + } else if i == 1 { + // second node is alive + node.On("StateAndLatest").Return(nodeStateAlive, int64(1), utils.NewBigI(7)) + } else { + // third node is alive and best + node.On("StateAndLatest").Return(nodeStateAlive, int64(2), utils.NewBigI(8)) + } + node.On("Order").Maybe().Return(int32(1)) + nodes = append(nodes, node) + } + + selector := newNodeSelector(NodeSelectionModeTotalDifficulty, nodes) + assert.Same(t, nodes[2], selector.Select()) + + t.Run("stick to the same node", func(t *testing.T) { + node := newMockNode[types.ID, Head, nodeClient](t) + // fourth node is alive (same as 3rd) + node.On("StateAndLatest").Return(nodeStateAlive, int64(2), utils.NewBigI(8)) + node.On("Order").Maybe().Return(int32(1)) + nodes = append(nodes, node) + + selector := newNodeSelector(NodeSelectionModeTotalDifficulty, nodes) + assert.Same(t, nodes[2], selector.Select()) + }) + + t.Run("another best node", func(t *testing.T) { + node := newMockNode[types.ID, Head, nodeClient](t) + // fifth node is alive (better than 3rd and 4th) + node.On("StateAndLatest").Return(nodeStateAlive, int64(3), utils.NewBigI(11)) + node.On("Order").Maybe().Return(int32(1)) + nodes = append(nodes, node) + + selector := newNodeSelector(NodeSelectionModeTotalDifficulty, nodes) + assert.Same(t, nodes[4], selector.Select()) + }) + + t.Run("nodes never update latest block number", func(t *testing.T) { + node1 := newMockNode[types.ID, Head, nodeClient](t) + node1.On("StateAndLatest").Return(nodeStateAlive, int64(-1), nil) + node1.On("Order").Maybe().Return(int32(1)) + node2 := newMockNode[types.ID, Head, nodeClient](t) + node2.On("StateAndLatest").Return(nodeStateAlive, int64(-1), nil) + node2.On("Order").Maybe().Return(int32(1)) + nodes := []Node[types.ID, Head, nodeClient]{node1, node2} + + selector := newNodeSelector(NodeSelectionModeTotalDifficulty, nodes) + assert.Same(t, node1, selector.Select()) + }) +} + +func TestTotalDifficultyNodeSelector_None(t *testing.T) { + t.Parallel() + + type nodeClient NodeClient[types.ID, Head] + var nodes []Node[types.ID, Head, nodeClient] + + for i := 0; i < 3; i++ { + node := newMockNode[types.ID, Head, nodeClient](t) + if i == 0 { + // first node is out of sync + node.On("StateAndLatest").Return(nodeStateOutOfSync, int64(-1), nil) + } else { + // others are unreachable + node.On("StateAndLatest").Return(nodeStateUnreachable, int64(1), utils.NewBigI(7)) + } + nodes = append(nodes, node) + } + + selector := newNodeSelector(NodeSelectionModeTotalDifficulty, nodes) + assert.Nil(t, selector.Select()) +} + +func TestTotalDifficultyNodeSelectorWithOrder(t *testing.T) { + t.Parallel() + + type nodeClient NodeClient[types.ID, Head] + var nodes []Node[types.ID, Head, nodeClient] + + t.Run("same td and order", func(t *testing.T) { + for i := 0; i < 3; i++ { + node := newMockNode[types.ID, Head, nodeClient](t) + node.On("StateAndLatest").Return(nodeStateAlive, int64(1), utils.NewBigI(10)) + node.On("Order").Return(int32(2)) + nodes = append(nodes, node) + } + selector := newNodeSelector(NodeSelectionModeTotalDifficulty, nodes) + //Should select the first node because all things are equal + assert.Same(t, nodes[0], selector.Select()) + }) + + t.Run("same td but different order", func(t *testing.T) { + node1 := newMockNode[types.ID, Head, nodeClient](t) + node1.On("StateAndLatest").Return(nodeStateAlive, int64(3), utils.NewBigI(10)) + node1.On("Order").Return(int32(3)) + + node2 := newMockNode[types.ID, Head, nodeClient](t) + node2.On("StateAndLatest").Return(nodeStateAlive, int64(3), 
utils.NewBigI(10)) + node2.On("Order").Return(int32(1)) + + node3 := newMockNode[types.ID, Head, nodeClient](t) + node3.On("StateAndLatest").Return(nodeStateAlive, int64(3), utils.NewBigI(10)) + node3.On("Order").Return(int32(2)) + + nodes := []Node[types.ID, Head, nodeClient]{node1, node2, node3} + selector := newNodeSelector(NodeSelectionModeTotalDifficulty, nodes) + //Should select the second node as it has the highest priority + assert.Same(t, nodes[1], selector.Select()) + }) + + t.Run("different td but same order", func(t *testing.T) { + node1 := newMockNode[types.ID, Head, nodeClient](t) + node1.On("StateAndLatest").Return(nodeStateAlive, int64(1), utils.NewBigI(10)) + node1.On("Order").Maybe().Return(int32(3)) + + node2 := newMockNode[types.ID, Head, nodeClient](t) + node2.On("StateAndLatest").Return(nodeStateAlive, int64(1), utils.NewBigI(11)) + node2.On("Order").Maybe().Return(int32(3)) + + node3 := newMockNode[types.ID, Head, nodeClient](t) + node3.On("StateAndLatest").Return(nodeStateAlive, int64(1), utils.NewBigI(12)) + node3.On("Order").Return(int32(3)) + + nodes := []Node[types.ID, Head, nodeClient]{node1, node2, node3} + selector := newNodeSelector(NodeSelectionModeTotalDifficulty, nodes) + //Should select the third node as it has the highest td + assert.Same(t, nodes[2], selector.Select()) + }) + + t.Run("different td and different order", func(t *testing.T) { + node1 := newMockNode[types.ID, Head, nodeClient](t) + node1.On("StateAndLatest").Return(nodeStateAlive, int64(1), utils.NewBigI(100)) + node1.On("Order").Maybe().Return(int32(4)) + + node2 := newMockNode[types.ID, Head, nodeClient](t) + node2.On("StateAndLatest").Return(nodeStateAlive, int64(1), utils.NewBigI(110)) + node2.On("Order").Maybe().Return(int32(5)) + + node3 := newMockNode[types.ID, Head, nodeClient](t) + node3.On("StateAndLatest").Return(nodeStateAlive, int64(1), utils.NewBigI(110)) + node3.On("Order").Maybe().Return(int32(1)) + + node4 := newMockNode[types.ID, Head, nodeClient](t) + node4.On("StateAndLatest").Return(nodeStateAlive, int64(1), utils.NewBigI(105)) + node4.On("Order").Maybe().Return(int32(2)) + + nodes := []Node[types.ID, Head, nodeClient]{node1, node2, node3, node4} + selector := newNodeSelector(NodeSelectionModeTotalDifficulty, nodes) + //Should select the third node as it has the highest td and will win the priority tie-breaker + assert.Same(t, nodes[2], selector.Select()) + }) +} diff --git a/common/client/node_test.go b/common/client/node_test.go new file mode 100644 index 00000000000..0438e11e612 --- /dev/null +++ b/common/client/node_test.go @@ -0,0 +1,80 @@ +package client + +import ( + "net/url" + "testing" + "time" + + "github.com/smartcontractkit/chainlink/v2/common/types" + "github.com/smartcontractkit/chainlink/v2/core/logger" +) + +type testNodeConfig struct { + pollFailureThreshold uint32 + pollInterval time.Duration + selectionMode string + syncThreshold uint32 +} + +func (n testNodeConfig) PollFailureThreshold() uint32 { + return n.pollFailureThreshold +} + +func (n testNodeConfig) PollInterval() time.Duration { + return n.pollInterval +} + +func (n testNodeConfig) SelectionMode() string { + return n.selectionMode +} + +func (n testNodeConfig) SyncThreshold() uint32 { + return n.syncThreshold +} + +type testNode struct { + *node[types.ID, Head, NodeClient[types.ID, Head]] +} + +type testNodeOpts struct { + config testNodeConfig + noNewHeadsThreshold time.Duration + lggr logger.Logger + wsuri url.URL + httpuri *url.URL + name string + id int32 + chainID types.ID + 
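// nodeOrder is forwarded to NewNode as the node's priority; the selector tests above treat a lower value as higher priority +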
nodeOrder int32 + rpc *mockNodeClient[types.ID, Head] + chainFamily string +} + +func newTestNode(t *testing.T, opts testNodeOpts) testNode { + if opts.lggr == nil { + opts.lggr = logger.TestLogger(t) + } + + if opts.name == "" { + opts.name = "test node" + } + + if opts.chainFamily == "" { + opts.chainFamily = "test node chain family" + } + + if opts.chainID == nil { + opts.chainID = types.RandomID() + } + + if opts.id == 0 { + opts.id = 42 + } + + nodeI := NewNode[types.ID, Head, NodeClient[types.ID, Head]](opts.config, opts.noNewHeadsThreshold, opts.lggr, + opts.wsuri, opts.httpuri, opts.name, opts.id, opts.chainID, opts.nodeOrder, opts.rpc, opts.chainFamily) + + return testNode{ + nodeI.(*node[types.ID, Head, NodeClient[types.ID, Head]]), + } +} diff --git a/common/client/send_only_node.go b/common/client/send_only_node.go index 0051fb014d9..fa793a826a6 100644 --- a/common/client/send_only_node.go +++ b/common/client/send_only_node.go @@ -13,6 +13,7 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/utils" ) +//go:generate mockery --quiet --name sendOnlyClient --structname mockSendOnlyClient --filename "mock_send_only_client_test.go" --inpackage --case=underscore type sendOnlyClient[ CHAIN_ID types.ID, ] interface { @@ -22,6 +23,8 @@ type sendOnlyClient[ } // SendOnlyNode represents one node used as a sendonly +// +//go:generate mockery --quiet --name SendOnlyNode --structname mockSendOnlyNode --filename "mock_send_only_node_test.go" --inpackage --case=underscore type SendOnlyNode[ CHAIN_ID types.ID, RPC sendOnlyClient[CHAIN_ID], @@ -143,6 +146,7 @@ func (s *sendOnlyNode[CHAIN_ID, RPC]) start(startCtx context.Context) { func (s *sendOnlyNode[CHAIN_ID, RPC]) Close() error { return s.StopOnce(s.name, func() error { s.rpc.Close() + close(s.chStop) s.wg.Wait() s.setState(nodeStateClosed) return nil diff --git a/common/client/send_only_node_test.go b/common/client/send_only_node_test.go new file mode 100644 index 00000000000..bfe55153656 --- /dev/null +++ b/common/client/send_only_node_test.go @@ -0,0 +1,139 @@ +package client + +import ( + "fmt" + "net/url" + "testing" + + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "go.uber.org/zap" + + "github.com/smartcontractkit/chainlink-relay/pkg/utils/tests" + + "github.com/smartcontractkit/chainlink/v2/common/types" + "github.com/smartcontractkit/chainlink/v2/core/logger" +) + +func TestNewSendOnlyNode(t *testing.T) { + t.Parallel() + + urlFormat := "http://user:%s@testurl.com" + password := "pass" + u, err := url.Parse(fmt.Sprintf(urlFormat, password)) + require.NoError(t, err) + redacted := fmt.Sprintf(urlFormat, "xxxxx") + lggr := logger.TestLogger(t) + name := "TestNewSendOnlyNode" + chainID := types.RandomID() + client := newMockSendOnlyClient[types.ID](t) + + node := NewSendOnlyNode(lggr, *u, name, chainID, client) + assert.NotNil(t, node) + + // Must contain name & url with redacted password + assert.Contains(t, node.String(), fmt.Sprintf("%s:%s", name, redacted)) + assert.Equal(t, node.ConfiguredChainID(), chainID) +} + +func TestStartSendOnlyNode(t *testing.T) { + t.Parallel() + t.Run("becomes unusable if initial dial fails", func(t *testing.T) { + t.Parallel() + lggr, observedLogs := logger.TestLoggerObserved(t, zap.WarnLevel) + client := newMockSendOnlyClient[types.ID](t) + client.On("Close").Once() + expectedError := errors.New("some http error") + client.On("DialHTTP").Return(expectedError).Once() + s := NewSendOnlyNode(lggr, 
url.URL{}, t.Name(), types.RandomID(), client) + + defer func() { assert.NoError(t, s.Close()) }() + err := s.Start(tests.Context(t)) + require.NoError(t, err) + + assert.Equal(t, nodeStateUnusable, s.State()) + tests.RequireLogMessage(t, observedLogs, "Dial failed: SendOnly Node is unusable") + }) + t.Run("Default ChainID(0) produces warn and skips checks", func(t *testing.T) { + t.Parallel() + lggr, observedLogs := logger.TestLoggerObserved(t, zap.WarnLevel) + client := newMockSendOnlyClient[types.ID](t) + client.On("Close").Once() + client.On("DialHTTP").Return(nil).Once() + s := NewSendOnlyNode(lggr, url.URL{}, t.Name(), types.NewIDFromInt(0), client) + + defer func() { assert.NoError(t, s.Close()) }() + err := s.Start(tests.Context(t)) + require.NoError(t, err) + + assert.Equal(t, nodeStateAlive, s.State()) + tests.RequireLogMessage(t, observedLogs, "sendonly rpc ChainID verification skipped") + }) + t.Run("Can recover from chainID verification failure", func(t *testing.T) { + t.Parallel() + lggr, observedLogs := logger.TestLoggerObserved(t, zap.WarnLevel) + client := newMockSendOnlyClient[types.ID](t) + client.On("Close").Once() + client.On("DialHTTP").Return(nil) + expectedError := errors.New("failed to get chain ID") + chainID := types.RandomID() + const failuresCount = 2 + client.On("ChainID", mock.Anything).Return(types.RandomID(), expectedError).Times(failuresCount) + client.On("ChainID", mock.Anything).Return(chainID, nil) + + s := NewSendOnlyNode(lggr, url.URL{}, t.Name(), chainID, client) + + defer func() { assert.NoError(t, s.Close()) }() + err := s.Start(tests.Context(t)) + require.NoError(t, err) + + assert.Equal(t, nodeStateUnreachable, s.State()) + tests.AssertLogCountEventually(t, observedLogs, fmt.Sprintf("Verify failed: %v", expectedError), failuresCount) + tests.AssertEventually(t, func() bool { + return s.State() == nodeStateAlive + }) + }) + t.Run("Can recover from chainID mismatch", func(t *testing.T) { + t.Parallel() + lggr, observedLogs := logger.TestLoggerObserved(t, zap.WarnLevel) + client := newMockSendOnlyClient[types.ID](t) + client.On("Close").Once() + client.On("DialHTTP").Return(nil).Once() + configuredChainID := types.NewIDFromInt(11) + rpcChainID := types.NewIDFromInt(20) + const failuresCount = 2 + client.On("ChainID", mock.Anything).Return(rpcChainID, nil).Times(failuresCount) + client.On("ChainID", mock.Anything).Return(configuredChainID, nil) + s := NewSendOnlyNode(lggr, url.URL{}, t.Name(), configuredChainID, client) + + defer func() { assert.NoError(t, s.Close()) }() + err := s.Start(tests.Context(t)) + require.NoError(t, err) + + assert.Equal(t, nodeStateInvalidChainID, s.State()) + tests.AssertLogCountEventually(t, observedLogs, "sendonly rpc ChainID doesn't match local chain ID", failuresCount) + tests.AssertEventually(t, func() bool { + return s.State() == nodeStateAlive + }) + }) + t.Run("Start with Random ChainID", func(t *testing.T) { + t.Parallel() + lggr, observedLogs := logger.TestLoggerObserved(t, zap.WarnLevel) + client := newMockSendOnlyClient[types.ID](t) + client.On("Close").Once() + client.On("DialHTTP").Return(nil).Once() + configuredChainID := types.RandomID() + client.On("ChainID", mock.Anything).Return(configuredChainID, nil) + s := NewSendOnlyNode(lggr, url.URL{}, t.Name(), configuredChainID, client) + + defer func() { assert.NoError(t, s.Close()) }() + err := s.Start(tests.Context(t)) + assert.NoError(t, err) + tests.AssertEventually(t, func() bool { + return s.State() == nodeStateAlive + }) + assert.Equal(t, 0, 
observedLogs.Len()) // No warnings expected + }) +} diff --git a/common/client/types.go b/common/client/types.go index f3a6029a9e8..0e52f1db72c 100644 --- a/common/client/types.go +++ b/common/client/types.go @@ -11,6 +11,8 @@ import ( ) // RPC includes all the necessary methods for a multi-node client to interact directly with any RPC endpoint. +// +//go:generate mockery --quiet --name RPC --structname mockRPC --inpackage --filename "mock_rpc_test.go" --case=underscore type RPC[ CHAIN_ID types.ID, SEQ types.Sequence, @@ -45,12 +47,16 @@ type RPC[ } // Head is the interface required by the NodeClient +// +//go:generate mockery --quiet --name Head --structname mockHead --filename "mock_head_test.go" --inpackage --case=underscore type Head interface { BlockNumber() int64 BlockDifficulty() *utils.Big } // NodeClient includes all the necessary RPC methods required by a node. +// +//go:generate mockery --quiet --name NodeClient --structname mockNodeClient --filename "mock_node_client_test.go" --inpackage --case=underscore type NodeClient[ CHAIN_ID types.ID, HEAD Head, diff --git a/common/types/test_utils.go b/common/types/test_utils.go new file mode 100644 index 00000000000..40560f7866c --- /dev/null +++ b/common/types/test_utils.go @@ -0,0 +1,16 @@ +package types + +import ( + "math" + "math/big" + "math/rand" +) + +func RandomID() ID { + id := rand.Int63n(math.MaxInt32) + 10000 + return big.NewInt(id) +} + +func NewIDFromInt(id int64) ID { + return big.NewInt(id) +}