diff --git a/core/capabilities/integration_tests/mock_dispatcher.go b/core/capabilities/integration_tests/mock_dispatcher.go index 371615b74ab..907c62dd3c3 100644 --- a/core/capabilities/integration_tests/mock_dispatcher.go +++ b/core/capabilities/integration_tests/mock_dispatcher.go @@ -2,6 +2,7 @@ package integration_tests import ( "context" + "fmt" "sync" "testing" "time" @@ -14,30 +15,47 @@ import ( "google.golang.org/protobuf/proto" ) -type receiverKey struct { - capabilityId string - donId uint32 -} - // testAsyncMessageBroker backs the dispatchers created for each node in the test and effectively // acts as the rageP2P network layer. type testAsyncMessageBroker struct { services.StateMachine t *testing.T - nodes map[p2ptypes.PeerID]*dispatcherNode + chanBufferSize int + stopCh services.StopChan + wg sync.WaitGroup - sendCh chan *remotetypes.MessageBody + peerIDToBrokerNode map[p2ptypes.PeerID]*brokerNode - chanBufferSize int + mux sync.Mutex +} - stopCh services.StopChan - wg sync.WaitGroup +func newTestAsyncMessageBroker(t *testing.T, chanBufferSize int) *testAsyncMessageBroker { + return &testAsyncMessageBroker{ + t: t, + stopCh: make(services.StopChan), + chanBufferSize: chanBufferSize, + peerIDToBrokerNode: make(map[p2ptypes.PeerID]*brokerNode), + } +} + +func (a *testAsyncMessageBroker) Start(ctx context.Context) error { + return a.StartOnce("testAsyncMessageBroker", func() error { + return nil + }) +} + +func (a *testAsyncMessageBroker) Close() error { + return a.StopOnce("testAsyncMessageBroker", func() error { + close(a.stopCh) + a.wg.Wait() + return nil + }) } // NewDispatcherForNode creates a new dispatcher for a node with the given peer ID. func (a *testAsyncMessageBroker) NewDispatcherForNode(nodePeerID p2ptypes.PeerID) remotetypes.Dispatcher { - return &nodeDispatcher{ + return &brokerDispatcher{ callerPeerID: nodePeerID, broker: a, } @@ -51,100 +69,82 @@ func (a *testAsyncMessageBroker) Name() string { return "testAsyncMessageBroker" } -func newTestAsyncMessageBroker(t *testing.T, chanBufferSize int) *testAsyncMessageBroker { - return &testAsyncMessageBroker{ - t: t, - nodes: make(map[p2ptypes.PeerID]*dispatcherNode), - stopCh: make(services.StopChan), - sendCh: make(chan *remotetypes.MessageBody, chanBufferSize), - chanBufferSize: chanBufferSize, +func (a *testAsyncMessageBroker) registerReceiverNode(nodePeerID p2ptypes.PeerID, capabilityId string, capabilityDonID uint32, receiver remotetypes.Receiver) { + a.mux.Lock() + defer a.mux.Unlock() + + node, ok := a.peerIDToBrokerNode[nodePeerID] + if !ok { + node = a.newNode() + a.peerIDToBrokerNode[nodePeerID] = node + } + + node.registerReceiverCh <- ®isterReceiverRequest{ + receiverKey: receiverKey{ + capabilityId: capabilityId, + donId: capabilityDonID, + }, + receiver: receiver, } } -func (a *testAsyncMessageBroker) Start(ctx context.Context) error { - return a.StartOnce("testAsyncMessageBroker", func() error { - a.wg.Add(1) - go func() { - defer a.wg.Done() - - for { - select { - case <-a.stopCh: - return - case msg := <-a.sendCh: - peerID := toPeerID(msg.Receiver) - node, ok := a.nodes[peerID] - if !ok { - panic("node not found for peer id") - } - - node.receiveCh <- msg - } - } - }() - return nil - }) +func (a *testAsyncMessageBroker) Send(msg *remotetypes.MessageBody) { + peerID := toPeerID(msg.Receiver) + node, ok := a.peerIDToBrokerNode[peerID] + if !ok { + panic(fmt.Sprintf("node not found for peer ID %v", peerID)) + } + + node.receiveCh <- msg } -func (a *testAsyncMessageBroker) Close() error { - return a.StopOnce("testAsyncMessageBroker", func() error { - close(a.stopCh) +type brokerNode struct { + registerReceiverCh chan *registerReceiverRequest + receiveCh chan *remotetypes.MessageBody +} - a.wg.Wait() - return nil - }) +type receiverKey struct { + capabilityId string + donId uint32 } -type dispatcherNode struct { - receivers map[receiverKey]remotetypes.Receiver - receiveCh chan *remotetypes.MessageBody +type registerReceiverRequest struct { + receiverKey + receiver remotetypes.Receiver } -func (a *testAsyncMessageBroker) registerReceiverNode(nodePeerID p2ptypes.PeerID, capabilityId string, capabilityDonID uint32, receiver remotetypes.Receiver) { - key := receiverKey{ - capabilityId: capabilityId, - donId: capabilityDonID, +func (a *testAsyncMessageBroker) newNode() *brokerNode { + result := &brokerNode{ + receiveCh: make(chan *remotetypes.MessageBody, a.chanBufferSize), + registerReceiverCh: make(chan *registerReceiverRequest, a.chanBufferSize), } - node, nodeExists := a.nodes[nodePeerID] - if !nodeExists { - node = &dispatcherNode{ - receivers: make(map[receiverKey]remotetypes.Receiver), - receiveCh: make(chan *remotetypes.MessageBody, a.chanBufferSize), - } - - a.wg.Add(1) - go func() { - defer a.wg.Done() - - for { - select { - case <-a.stopCh: - return - case msg := <-node.receiveCh: - k := receiverKey{ - capabilityId: msg.CapabilityId, - donId: msg.CapabilityDonId, - } - - r, ok := node.receivers[k] - if !ok { - panic("receiver not found for key") - } - - r.Receive(tests.Context(a.t), msg) + a.wg.Add(1) + go func() { + defer a.wg.Done() + receivers := make(map[receiverKey]remotetypes.Receiver) + for { + select { + case <-a.stopCh: + return + case msg := <-result.receiveCh: + k := receiverKey{ + capabilityId: msg.CapabilityId, + donId: msg.CapabilityDonId, } - } - }() - a.nodes[nodePeerID] = node - } - - node.receivers[key] = receiver -} + r, ok := receivers[k] + if !ok { + panic(fmt.Sprintf("receiver not found for key %+v", k)) + } -func (a *testAsyncMessageBroker) Send(msg *remotetypes.MessageBody) { - a.sendCh <- msg + r.Receive(tests.Context(a.t), msg) + case reg := <-result.registerReceiverCh: + receivers[reg.receiverKey] = reg.receiver + } + } + }() + return result } func toPeerID(id []byte) p2ptypes.PeerID { @@ -155,12 +155,12 @@ type broker interface { Send(msg *remotetypes.MessageBody) } -type nodeDispatcher struct { +type brokerDispatcher struct { callerPeerID p2ptypes.PeerID broker broker } -func (t *nodeDispatcher) Send(peerID p2ptypes.PeerID, msgBody *remotetypes.MessageBody) error { +func (t *brokerDispatcher) Send(peerID p2ptypes.PeerID, msgBody *remotetypes.MessageBody) error { clonedMsg := proto.Clone(msgBody).(*remotetypes.MessageBody) clonedMsg.Version = 1 clonedMsg.Sender = t.callerPeerID[:] @@ -170,8 +170,8 @@ func (t *nodeDispatcher) Send(peerID p2ptypes.PeerID, msgBody *remotetypes.Messa return nil } -func (t *nodeDispatcher) SetReceiver(capabilityId string, donId uint32, receiver remotetypes.Receiver) error { +func (t *brokerDispatcher) SetReceiver(capabilityId string, donId uint32, receiver remotetypes.Receiver) error { t.broker.(*testAsyncMessageBroker).registerReceiverNode(t.callerPeerID, capabilityId, donId, receiver) return nil } -func (t *nodeDispatcher) RemoveReceiver(capabilityId string, donId uint32) {} +func (t *brokerDispatcher) RemoveReceiver(capabilityId string, donId uint32) {} diff --git a/core/capabilities/integration_tests/mock_libocr.go b/core/capabilities/integration_tests/mock_libocr.go index a36686c6bde..2c786dc28ea 100644 --- a/core/capabilities/integration_tests/mock_libocr.go +++ b/core/capabilities/integration_tests/mock_libocr.go @@ -18,7 +18,7 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/services/keystore/keys/ocr2key" ) -type node struct { +type libocrNode struct { ocr3types.ReportingPlugin[[]byte] *ocr3.ContractTransmitter key ocr2key.KeyBundle @@ -27,7 +27,7 @@ type node struct { // mockLibOCR is a mock libocr implementation for testing purposes that simulates libocr protocol rounds without having // to setup the libocr network type mockLibOCR struct { - nodes []*node + nodes []*libocrNode f uint8 } @@ -55,7 +55,7 @@ func (m *mockLibOCR) Start(ctx context.Context, t *testing.T, protocolRoundInter } func (m *mockLibOCR) AddNode(plugin ocr3types.ReportingPlugin[[]byte], transmitter *ocr3.ContractTransmitter, key ocr2key.KeyBundle) { - m.nodes = append(m.nodes, &node{plugin, transmitter, key}) + m.nodes = append(m.nodes, &libocrNode{plugin, transmitter, key}) } func (m *mockLibOCR) simulateProtocolRound(ctx context.Context) error {