diff --git a/network/p2pNetwork.go b/network/p2pNetwork.go index cc37961d51..3d6f856770 100644 --- a/network/p2pNetwork.go +++ b/network/p2pNetwork.go @@ -156,10 +156,30 @@ func (n *P2PNetwork) Stop() { n.wsPeersConnectivityCheckTicker = nil } n.ctxCancel() + n.innerStop() n.service.Close() n.wg.Wait() } +// innerStop context for shutting down peers +func (n *P2PNetwork) innerStop() { + n.wsPeersLock.Lock() + defer n.wsPeersLock.Unlock() + closeGroup := sync.WaitGroup{} + closeGroup.Add(len(n.wsPeers)) + deadline := time.Now().Add(peerDisconnectionAckDuration) + for peerID, peer := range n.wsPeers { + // we need to both close the wsPeer and close the p2p connection + go closeWaiter(&closeGroup, peer, deadline) + err := n.service.ClosePeer(peerID) + if err != nil { + n.log.Warnf("Error closing peer %s: %v", peerID, err) + } + delete(n.wsPeers, peerID) + } + closeGroup.Wait() +} + func (n *P2PNetwork) meshThread() { defer n.wg.Done() timer := time.NewTicker(meshThreadInterval) @@ -220,6 +240,15 @@ func (n *P2PNetwork) Disconnect(badnode Peer) { if err != nil { n.log.Warnf("Error disconnecting from peer %s: %v", node, err) } + n.wsPeersLock.Lock() + defer n.wsPeersLock.Unlock() + if wsPeer, ok := n.wsPeers[node]; ok { + wsPeer.CloseAndWait(time.Now().Add(peerDisconnectionAckDuration)) + delete(n.wsPeers, node) + } else { + n.log.Warnf("Could not find wsPeer reference for peer %s", node) + } + default: n.log.Warnf("Unknown peer type %T", badnode) } @@ -382,7 +411,7 @@ func (n *P2PNetwork) txTopicHandleLoop() { for { msg, err := sub.Next(n.ctx) if err != nil { - if err != pubsub.ErrSubscriptionCancelled { + if err != pubsub.ErrSubscriptionCancelled && err != context.Canceled { n.log.Errorf("Error reading from subscription %v, peerId %s", err, n.service.ID()) } sub.Cancel() diff --git a/network/p2pNetwork_test.go b/network/p2pNetwork_test.go index c3c0522a7e..86c6de1c40 100644 --- a/network/p2pNetwork_test.go +++ b/network/p2pNetwork_test.go @@ -63,7 +63,9 @@ func TestP2PSubmitTX(t *testing.T) { require.Eventually( t, func() bool { - return len(netA.service.ListPeersForTopic(p2p.TXTopicName)) == 2 && len(netB.service.ListPeersForTopic(p2p.TXTopicName)) == 1 && len(netC.service.ListPeersForTopic(p2p.TXTopicName)) == 1 + return len(netA.service.ListPeersForTopic(p2p.TXTopicName)) == 2 && + len(netB.service.ListPeersForTopic(p2p.TXTopicName)) == 1 && + len(netC.service.ListPeersForTopic(p2p.TXTopicName)) == 1 }, 2*time.Second, 50*time.Millisecond,