From fca1adbad702e5c78f31357dce4e0cdd92c13fd3 Mon Sep 17 00:00:00 2001 From: Manu NALEPA Date: Wed, 20 Mar 2024 04:36:00 +0100 Subject: [PATCH] Re-design `TestStartDiscV5_DiscoverPeersWithSubnets` test (#13766) * `Test_AttSubnets`: Factorize. * `filterPeerForAttSubnet`: `O(n)` ==> `O(1)` * `FindPeersWithSubnet`: Optimize. * `TestStartDiscV5_DiscoverPeersWithSubnets`: Complete re-design. * `broadcastAttestation`: User `log.WithFields`. * `filterPeer`: Refactor comments. * Make deepsource happy. * `TestStartDiscV5_FindPeersWithSubnet`: Add context cancellation. Add some notes on `FindPeersWithSubnet` about this limitation as well. --- beacon-chain/p2p/broadcaster.go | 18 +- beacon-chain/p2p/discovery.go | 33 ++-- beacon-chain/p2p/subnets.go | 39 ++-- beacon-chain/p2p/subnets_test.go | 312 ++++++++++++++++--------------- 4 files changed, 212 insertions(+), 190 deletions(-) diff --git a/beacon-chain/p2p/broadcaster.go b/beacon-chain/p2p/broadcaster.go index 4581bae1bd6e..10b209390800 100644 --- a/beacon-chain/p2p/broadcaster.go +++ b/beacon-chain/p2p/broadcaster.go @@ -15,6 +15,7 @@ import ( "github.com/prysmaticlabs/prysm/v5/monitoring/tracing" ethpb "github.com/prysmaticlabs/prysm/v5/proto/prysm/v1alpha1" "github.com/prysmaticlabs/prysm/v5/time/slots" + "github.com/sirupsen/logrus" "go.opencensus.io/trace" "google.golang.org/protobuf/proto" ) @@ -68,7 +69,7 @@ func (s *Service) BroadcastAttestation(ctx context.Context, subnet uint64, att * } // Non-blocking broadcast, with attempts to discover a subnet peer if none available. - go s.broadcastAttestation(ctx, subnet, att, forkDigest) + go s.internalBroadcastAttestation(ctx, subnet, att, forkDigest) return nil } @@ -94,8 +95,8 @@ func (s *Service) BroadcastSyncCommitteeMessage(ctx context.Context, subnet uint return nil } -func (s *Service) broadcastAttestation(ctx context.Context, subnet uint64, att *ethpb.Attestation, forkDigest [4]byte) { - _, span := trace.StartSpan(ctx, "p2p.broadcastAttestation") +func (s *Service) internalBroadcastAttestation(ctx context.Context, subnet uint64, att *ethpb.Attestation, forkDigest [4]byte) { + _, span := trace.StartSpan(ctx, "p2p.internalBroadcastAttestation") defer span.End() ctx = trace.NewContext(context.Background(), span) // clear parent context / deadline. @@ -137,7 +138,10 @@ func (s *Service) broadcastAttestation(ctx context.Context, subnet uint64, att * // acceptable threshold, we exit early and do not broadcast it. currSlot := slots.CurrentSlot(uint64(s.genesisTime.Unix())) if att.Data.Slot+params.BeaconConfig().SlotsPerEpoch < currSlot { - log.Warnf("Attestation is too old to broadcast, discarding it. Current Slot: %d , Attestation Slot: %d", currSlot, att.Data.Slot) + log.WithFields(logrus.Fields{ + "attestationSlot": att.Data.Slot, + "currentSlot": currSlot, + }).Warning("Attestation is too old to broadcast, discarding it") return } @@ -218,13 +222,13 @@ func (s *Service) BroadcastBlob(ctx context.Context, subnet uint64, blob *ethpb. } // Non-blocking broadcast, with attempts to discover a subnet peer if none available. - go s.broadcastBlob(ctx, subnet, blob, forkDigest) + go s.internalBroadcastBlob(ctx, subnet, blob, forkDigest) return nil } -func (s *Service) broadcastBlob(ctx context.Context, subnet uint64, blobSidecar *ethpb.BlobSidecar, forkDigest [4]byte) { - _, span := trace.StartSpan(ctx, "p2p.broadcastBlob") +func (s *Service) internalBroadcastBlob(ctx context.Context, subnet uint64, blobSidecar *ethpb.BlobSidecar, forkDigest [4]byte) { + _, span := trace.StartSpan(ctx, "p2p.internalBroadcastBlob") defer span.End() ctx = trace.NewContext(context.Background(), span) // clear parent context / deadline. diff --git a/beacon-chain/p2p/discovery.go b/beacon-chain/p2p/discovery.go index e67d942c4d28..5a08101a28ed 100644 --- a/beacon-chain/p2p/discovery.go +++ b/beacon-chain/p2p/discovery.go @@ -277,58 +277,69 @@ func (s *Service) startDiscoveryV5( // filterPeer validates each node that we retrieve from our dht. We // try to ascertain that the peer can be a valid protocol peer. // Validity Conditions: -// 1. The local node is still actively looking for peers to -// connect to. -// 2. Peer has a valid IP and TCP port set in their enr. -// 3. Peer hasn't been marked as 'bad' -// 4. Peer is not currently active or connected. -// 5. Peer is ready to receive incoming connections. -// 6. Peer's fork digest in their ENR matches that of +// 1. Peer has a valid IP and TCP port set in their enr. +// 2. Peer hasn't been marked as 'bad'. +// 3. Peer is not currently active or connected. +// 4. Peer is ready to receive incoming connections. +// 5. Peer's fork digest in their ENR matches that of // our localnodes. func (s *Service) filterPeer(node *enode.Node) bool { // Ignore nil node entries passed in. if node == nil { return false } - // ignore nodes with no ip address stored. + + // Ignore nodes with no IP address stored. if node.IP() == nil { return false } - // do not dial nodes with their tcp ports not set + + // Ignore nodes with their TCP ports not set. if err := node.Record().Load(enr.WithEntry("tcp", new(enr.TCP))); err != nil { if !enr.IsNotFound(err) { log.WithError(err).Debug("Could not retrieve tcp port") } return false } + peerData, multiAddr, err := convertToAddrInfo(node) if err != nil { log.WithError(err).Debug("Could not convert to peer data") return false } + + // Ignore bad nodes. if s.peers.IsBad(peerData.ID) { return false } + + // Ignore nodes that are already active. if s.peers.IsActive(peerData.ID) { return false } + + // Ignore nodes that are already connected. if s.host.Network().Connectedness(peerData.ID) == network.Connected { return false } + + // Ignore nodes that are not ready to receive incoming connections. if !s.peers.IsReadyToDial(peerData.ID) { return false } + + // Ignore nodes that don't match our fork digest. nodeENR := node.Record() - // Decide whether or not to connect to peer that does not - // match the proper fork ENR data with our local node. if s.genesisValidatorsRoot != nil { if err := s.compareForkENR(nodeENR); err != nil { log.WithError(err).Trace("Fork ENR mismatches between peer and local node") return false } } + // Add peer to peer handler. s.peers.Add(nodeENR, peerData.ID, multiAddr, network.DirUnknown) + return true } diff --git a/beacon-chain/p2p/subnets.go b/beacon-chain/p2p/subnets.go index 89ca590478bd..8d313db8f558 100644 --- a/beacon-chain/p2p/subnets.go +++ b/beacon-chain/p2p/subnets.go @@ -46,9 +46,13 @@ const syncLockerVal = 100 const blobSubnetLockerVal = 110 // FindPeersWithSubnet performs a network search for peers -// subscribed to a particular subnet. Then we try to connect -// with those peers. This method will block until the required amount of -// peers are found, the method only exits in the event of context timeouts. +// subscribed to a particular subnet. Then it tries to connect +// with those peers. This method will block until either: +// - the required amount of peers are found, or +// - the context is terminated. +// On some edge cases, this method may hang indefinitely while peers +// are actually found. In such a case, the user should cancel the context +// and re-run the method again. func (s *Service) FindPeersWithSubnet(ctx context.Context, topic string, index uint64, threshold int) (bool, error) { ctx, span := trace.StartSpan(ctx, "p2p.FindPeersWithSubnet") @@ -73,9 +77,9 @@ func (s *Service) FindPeersWithSubnet(ctx context.Context, topic string, return false, errors.New("no subnet exists for provided topic") } - currNum := len(s.pubsub.ListPeers(topic)) wg := new(sync.WaitGroup) for { + currNum := len(s.pubsub.ListPeers(topic)) if currNum >= threshold { break } @@ -99,7 +103,6 @@ func (s *Service) FindPeersWithSubnet(ctx context.Context, topic string, } // Wait for all dials to be completed. wg.Wait() - currNum = len(s.pubsub.ListPeers(topic)) } return true, nil } @@ -110,18 +113,13 @@ func (s *Service) filterPeerForAttSubnet(index uint64) func(node *enode.Node) bo if !s.filterPeer(node) { return false } + subnets, err := attSubnets(node.Record()) if err != nil { return false } - indExists := false - for _, comIdx := range subnets { - if comIdx == index { - indExists = true - break - } - } - return indExists + + return subnets[index] } } @@ -205,8 +203,10 @@ func initializePersistentSubnets(id enode.ID, epoch primitives.Epoch) error { // // return [compute_subscribed_subnet(node_id, epoch, index) for index in range(SUBNETS_PER_NODE)] func computeSubscribedSubnets(nodeID enode.ID, epoch primitives.Epoch) ([]uint64, error) { - subs := []uint64{} - for i := uint64(0); i < params.BeaconConfig().SubnetsPerNode; i++ { + subnetsPerNode := params.BeaconConfig().SubnetsPerNode + subs := make([]uint64, 0, subnetsPerNode) + + for i := uint64(0); i < subnetsPerNode; i++ { sub, err := computeSubscribedSubnet(nodeID, epoch, i) if err != nil { return nil, err @@ -281,19 +281,20 @@ func initializeSyncCommSubnets(node *enode.LocalNode) *enode.LocalNode { // Reads the attestation subnets entry from a node's ENR and determines // the committee indices of the attestation subnets the node is subscribed to. -func attSubnets(record *enr.Record) ([]uint64, error) { +func attSubnets(record *enr.Record) (map[uint64]bool, error) { bitV, err := attBitvector(record) if err != nil { return nil, err } + committeeIdxs := make(map[uint64]bool) // lint:ignore uintcast -- subnet count can be safely cast to int. if len(bitV) != byteCount(int(attestationSubnetCount)) { - return []uint64{}, errors.Errorf("invalid bitvector provided, it has a size of %d", len(bitV)) + return committeeIdxs, errors.Errorf("invalid bitvector provided, it has a size of %d", len(bitV)) } - var committeeIdxs []uint64 + for i := uint64(0); i < attestationSubnetCount; i++ { if bitV.BitAt(i) { - committeeIdxs = append(committeeIdxs, i) + committeeIdxs[i] = true } } return committeeIdxs, nil diff --git a/beacon-chain/p2p/subnets_test.go b/beacon-chain/p2p/subnets_test.go index d5fc81168541..92f0b38107e0 100644 --- a/beacon-chain/p2p/subnets_test.go +++ b/beacon-chain/p2p/subnets_test.go @@ -3,161 +3,197 @@ package p2p import ( "context" "crypto/rand" + "encoding/hex" + "fmt" "reflect" "testing" "time" - "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enr" "github.com/libp2p/go-libp2p/core/crypto" "github.com/prysmaticlabs/go-bitfield" "github.com/prysmaticlabs/prysm/v5/beacon-chain/cache" - "github.com/prysmaticlabs/prysm/v5/beacon-chain/startup" "github.com/prysmaticlabs/prysm/v5/cmd/beacon-chain/flags" "github.com/prysmaticlabs/prysm/v5/config/params" - "github.com/prysmaticlabs/prysm/v5/consensus-types/wrapper" ecdsaprysm "github.com/prysmaticlabs/prysm/v5/crypto/ecdsa" - pb "github.com/prysmaticlabs/prysm/v5/proto/prysm/v1alpha1" "github.com/prysmaticlabs/prysm/v5/testing/assert" "github.com/prysmaticlabs/prysm/v5/testing/require" ) -func TestStartDiscV5_DiscoverPeersWithSubnets(t *testing.T) { +func TestStartDiscV5_FindPeersWithSubnet(t *testing.T) { + // Topology of this test: + // + // + // Node 1 (subscribed to subnet 1) --\ + // | + // Node 2 (subscribed to subnet 2) --+--> BootNode (not subscribed to any subnet) <------- Node 0 (not subscribed to any subnet) + // | + // Node 3 (subscribed to subnet 3) --/ + // + // The purpose of this test is to ensure that the "Node 0" (connected only to the boot node) is able to + // find and connect to a node already subscribed to a specific subnet. + // In our case: The node i is subscribed to subnet i, with i = 1, 2, 3 + + // Define the genesis validators root, to ensure everybody is on the same network. + const genesisValidatorRootStr = "0xdeadbeefcafecafedeadbeefcafecafedeadbeefcafecafedeadbeefcafecafe" + genesisValidatorsRoot, err := hex.DecodeString(genesisValidatorRootStr[2:]) + require.NoError(t, err) + + // Create a context. + ctx := context.Background() + + // Use shorter period for testing. + currentPeriod := pollingPeriod + pollingPeriod = 1 * time.Second + defer func() { + pollingPeriod = currentPeriod + }() + + // Create flags. params.SetupTestConfigCleanup(t) - // This test needs to be entirely rewritten and should be done in a follow up PR from #7885. - t.Skip("This test is now failing after PR 7885 due to false positive") gFlags := new(flags.GlobalFlags) - gFlags.MinimumPeersPerSubnet = 4 + gFlags.MinimumPeersPerSubnet = 1 flags.Init(gFlags) + + params.BeaconNetworkConfig().MinimumPeersInSubnetSearch = 1 + // Reset config. defer flags.Init(new(flags.GlobalFlags)) - port := 2000 + + // First, generate a bootstrap node. ipAddr, pkey := createAddrAndPrivKey(t) genesisTime := time.Now() - genesisValidatorsRoot := make([]byte, 32) - s := &Service{ - cfg: &Config{UDPPort: uint(port)}, + + bootNodeService := &Service{ + cfg: &Config{TCPPort: 2000, UDPPort: 3000}, genesisTime: genesisTime, genesisValidatorsRoot: genesisValidatorsRoot, } - bootListener, err := s.createListener(ipAddr, pkey) + + bootNodeForkDigest, err := bootNodeService.currentForkDigest() + require.NoError(t, err) + + bootListener, err := bootNodeService.createListener(ipAddr, pkey) require.NoError(t, err) defer bootListener.Close() - bootNode := bootListener.Self() - // Use shorter period for testing. - currentPeriod := pollingPeriod - pollingPeriod = 1 * time.Second - defer func() { - pollingPeriod = currentPeriod - }() + bootNodeENR := bootListener.Self().String() + + // Create 3 nodes, each subscribed to a different subnet. + // Each node is connected to the boostrap node. + services := make([]*Service, 0, 3) - var listeners []*discover.UDPv5 for i := 1; i <= 3; i++ { - port = 3000 + i - cfg := &Config{ - Discv5BootStrapAddrs: []string{bootNode.String()}, + subnet := uint64(i) + service, err := NewService(ctx, &Config{ + Discv5BootStrapAddrs: []string{bootNodeENR}, MaxPeers: 30, - UDPPort: uint(port), - } - ipAddr, pkey := createAddrAndPrivKey(t) - s = &Service{ - cfg: cfg, - genesisTime: genesisTime, - genesisValidatorsRoot: genesisValidatorsRoot, - } - listener, err := s.startDiscoveryV5(ipAddr, pkey) - assert.NoError(t, err, "Could not start discovery for node") - bitV := bitfield.NewBitvector64() - bitV.SetBitAt(uint64(i), true) + TCPPort: uint(2000 + i), + UDPPort: uint(3000 + i), + }) + + require.NoError(t, err) + + service.genesisTime = genesisTime + service.genesisValidatorsRoot = genesisValidatorsRoot + + nodeForkDigest, err := service.currentForkDigest() + require.NoError(t, err) + require.Equal(t, true, nodeForkDigest == bootNodeForkDigest, "fork digest of the node doesn't match the boot node") + + // Start the service. + service.Start() + // Set the ENR `attnets`, used by Prysm to filter peers by subnet. + bitV := bitfield.NewBitvector64() + bitV.SetBitAt(subnet, true) entry := enr.WithEntry(attSubnetEnrKey, &bitV) - listener.LocalNode().Set(entry) - listeners = append(listeners, listener) + service.dv5Listener.LocalNode().Set(entry) + + // Join and subscribe to the subnet, needed by libp2p. + topic, err := service.pubsub.Join(fmt.Sprintf(AttestationSubnetTopicFormat, bootNodeForkDigest, subnet) + "/ssz_snappy") + require.NoError(t, err) + + _, err = topic.Subscribe() + require.NoError(t, err) + + // Memoize the service. + services = append(services, service) } + + // Stop the services. defer func() { - // Close down all peers. - for _, listener := range listeners { - listener.Close() + for _, service := range services { + err := service.Stop() + require.NoError(t, err) } }() - // Make one service on port 4001. - port = 4001 - gs := startup.NewClockSynchronizer() cfg := &Config{ - Discv5BootStrapAddrs: []string{bootNode.String()}, + Discv5BootStrapAddrs: []string{bootNodeENR}, MaxPeers: 30, - UDPPort: uint(port), - ClockWaiter: gs, + TCPPort: 2010, + UDPPort: 3010, } - s, err = NewService(context.Background(), cfg) + + service, err := NewService(ctx, cfg) require.NoError(t, err) - exitRoutine := make(chan bool) - go func() { - s.Start() - <-exitRoutine + service.genesisTime = genesisTime + service.genesisValidatorsRoot = genesisValidatorsRoot + + service.Start() + defer func() { + err := service.Stop() + require.NoError(t, err) }() - time.Sleep(50 * time.Millisecond) - // Send in a loop to ensure it is delivered (busy wait for the service to subscribe to the state feed). - var vr [32]byte - require.NoError(t, gs.SetClock(startup.NewClock(time.Now(), vr))) - // Wait for the nodes to have their local routing tables to be populated with the other nodes - time.Sleep(6 * discoveryWaitTime) + // Look up 3 different subnets. + exists := make([]bool, 0, 3) + for i := 1; i <= 3; i++ { + subnet := uint64(i) + topic := fmt.Sprintf(AttestationSubnetTopicFormat, bootNodeForkDigest, subnet) - // look up 3 different subnets - ctx := context.Background() - exists, err := s.FindPeersWithSubnet(ctx, "", 1, flags.Get().MinimumPeersPerSubnet) - require.NoError(t, err) - exists2, err := s.FindPeersWithSubnet(ctx, "", 2, flags.Get().MinimumPeersPerSubnet) - require.NoError(t, err) - exists3, err := s.FindPeersWithSubnet(ctx, "", 3, flags.Get().MinimumPeersPerSubnet) - require.NoError(t, err) - if !exists || !exists2 || !exists3 { - t.Fatal("Peer with subnet doesn't exist") - } + exist := false - // Update ENR of a peer. - testService := &Service{ - dv5Listener: listeners[0], - metaData: wrapper.WrappedMetadataV0(&pb.MetaDataV0{ - Attnets: bitfield.NewBitvector64(), - }), - } - cache.SubnetIDs.AddAttesterSubnetID(0, 10) - testService.RefreshENR() - time.Sleep(2 * time.Second) + // This for loop is used to ensure we don't get stuck in `FindPeersWithSubnet`. + // Read the documentation of `FindPeersWithSubnet` for more details. + for j := 0; j < 3; j++ { + ctxWithTimeOut, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() - exists, err = s.FindPeersWithSubnet(ctx, "", 2, flags.Get().MinimumPeersPerSubnet) - require.NoError(t, err) + exist, err = service.FindPeersWithSubnet(ctxWithTimeOut, topic, subnet, 1) + require.NoError(t, err) - assert.Equal(t, true, exists, "Peer with subnet doesn't exist") - assert.NoError(t, s.Stop()) - exitRoutine <- true + if exist { + break + } + } + + require.NoError(t, err) + exists = append(exists, exist) + + } + + // Check if all peers are found. + for _, exist := range exists { + require.Equal(t, true, exist, "Peer with subnet doesn't exist") + } } func Test_AttSubnets(t *testing.T) { params.SetupTestConfigCleanup(t) tests := []struct { name string - record func(t *testing.T) *enr.Record + record func(localNode *enode.LocalNode) *enr.Record want []uint64 wantErr bool errContains string }{ { name: "valid record", - record: func(t *testing.T) *enr.Record { - db, err := enode.OpenDB("") - assert.NoError(t, err) - priv, _, err := crypto.GenerateSecp256k1Key(rand.Reader) - assert.NoError(t, err) - convertedKey, err := ecdsaprysm.ConvertFromInterfacePrivKey(priv) - assert.NoError(t, err) - localNode := enode.NewLocalNode(db, convertedKey) + record: func(localNode *enode.LocalNode) *enr.Record { localNode = initializeAttSubnets(localNode) return localNode.Node().Record() }, @@ -166,14 +202,7 @@ func Test_AttSubnets(t *testing.T) { }, { name: "too small subnet", - record: func(t *testing.T) *enr.Record { - db, err := enode.OpenDB("") - assert.NoError(t, err) - priv, _, err := crypto.GenerateSecp256k1Key(rand.Reader) - assert.NoError(t, err) - convertedKey, err := ecdsaprysm.ConvertFromInterfacePrivKey(priv) - assert.NoError(t, err) - localNode := enode.NewLocalNode(db, convertedKey) + record: func(localNode *enode.LocalNode) *enr.Record { entry := enr.WithEntry(attSubnetEnrKey, []byte{}) localNode.Set(entry) return localNode.Node().Record() @@ -184,14 +213,7 @@ func Test_AttSubnets(t *testing.T) { }, { name: "half sized subnet", - record: func(t *testing.T) *enr.Record { - db, err := enode.OpenDB("") - assert.NoError(t, err) - priv, _, err := crypto.GenerateSecp256k1Key(rand.Reader) - assert.NoError(t, err) - convertedKey, err := ecdsaprysm.ConvertFromInterfacePrivKey(priv) - assert.NoError(t, err) - localNode := enode.NewLocalNode(db, convertedKey) + record: func(localNode *enode.LocalNode) *enr.Record { entry := enr.WithEntry(attSubnetEnrKey, make([]byte, 4)) localNode.Set(entry) return localNode.Node().Record() @@ -202,14 +224,7 @@ func Test_AttSubnets(t *testing.T) { }, { name: "too large subnet", - record: func(t *testing.T) *enr.Record { - db, err := enode.OpenDB("") - assert.NoError(t, err) - priv, _, err := crypto.GenerateSecp256k1Key(rand.Reader) - assert.NoError(t, err) - convertedKey, err := ecdsaprysm.ConvertFromInterfacePrivKey(priv) - assert.NoError(t, err) - localNode := enode.NewLocalNode(db, convertedKey) + record: func(localNode *enode.LocalNode) *enr.Record { entry := enr.WithEntry(attSubnetEnrKey, make([]byte, byteCount(int(attestationSubnetCount))+1)) localNode.Set(entry) return localNode.Node().Record() @@ -220,14 +235,7 @@ func Test_AttSubnets(t *testing.T) { }, { name: "very large subnet", - record: func(t *testing.T) *enr.Record { - db, err := enode.OpenDB("") - assert.NoError(t, err) - priv, _, err := crypto.GenerateSecp256k1Key(rand.Reader) - assert.NoError(t, err) - convertedKey, err := ecdsaprysm.ConvertFromInterfacePrivKey(priv) - assert.NoError(t, err) - localNode := enode.NewLocalNode(db, convertedKey) + record: func(localNode *enode.LocalNode) *enr.Record { entry := enr.WithEntry(attSubnetEnrKey, make([]byte, byteCount(int(attestationSubnetCount))+100)) localNode.Set(entry) return localNode.Node().Record() @@ -238,14 +246,7 @@ func Test_AttSubnets(t *testing.T) { }, { name: "single subnet", - record: func(t *testing.T) *enr.Record { - db, err := enode.OpenDB("") - assert.NoError(t, err) - priv, _, err := crypto.GenerateSecp256k1Key(rand.Reader) - assert.NoError(t, err) - convertedKey, err := ecdsaprysm.ConvertFromInterfacePrivKey(priv) - assert.NoError(t, err) - localNode := enode.NewLocalNode(db, convertedKey) + record: func(localNode *enode.LocalNode) *enr.Record { bitV := bitfield.NewBitvector64() bitV.SetBitAt(0, true) entry := enr.WithEntry(attSubnetEnrKey, bitV.Bytes()) @@ -257,17 +258,10 @@ func Test_AttSubnets(t *testing.T) { }, { name: "multiple subnets", - record: func(t *testing.T) *enr.Record { - db, err := enode.OpenDB("") - assert.NoError(t, err) - priv, _, err := crypto.GenerateSecp256k1Key(rand.Reader) - assert.NoError(t, err) - convertedKey, err := ecdsaprysm.ConvertFromInterfacePrivKey(priv) - assert.NoError(t, err) - localNode := enode.NewLocalNode(db, convertedKey) + record: func(localNode *enode.LocalNode) *enr.Record { bitV := bitfield.NewBitvector64() for i := uint64(0); i < bitV.Len(); i++ { - // skip 2 subnets + // Keep only odd subnets. if (i+1)%2 == 0 { continue } @@ -285,14 +279,7 @@ func Test_AttSubnets(t *testing.T) { }, { name: "all subnets", - record: func(t *testing.T) *enr.Record { - db, err := enode.OpenDB("") - assert.NoError(t, err) - priv, _, err := crypto.GenerateSecp256k1Key(rand.Reader) - assert.NoError(t, err) - convertedKey, err := ecdsaprysm.ConvertFromInterfacePrivKey(priv) - assert.NoError(t, err) - localNode := enode.NewLocalNode(db, convertedKey) + record: func(localNode *enode.LocalNode) *enr.Record { bitV := bitfield.NewBitvector64() for i := uint64(0); i < bitV.Len(); i++ { bitV.SetBitAt(i, true) @@ -309,16 +296,35 @@ func Test_AttSubnets(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := attSubnets(tt.record(t)) + db, err := enode.OpenDB("") + assert.NoError(t, err) + + priv, _, err := crypto.GenerateSecp256k1Key(rand.Reader) + assert.NoError(t, err) + + convertedKey, err := ecdsaprysm.ConvertFromInterfacePrivKey(priv) + assert.NoError(t, err) + + localNode := enode.NewLocalNode(db, convertedKey) + record := tt.record(localNode) + + got, err := attSubnets(record) if (err != nil) != tt.wantErr { t.Errorf("syncSubnets() error = %v, wantErr %v", err, tt.wantErr) return } + if tt.wantErr { assert.ErrorContains(t, tt.errContains, err) } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("syncSubnets() got = %v, want %v", got, tt.want) + + want := make(map[uint64]bool, len(tt.want)) + for _, subnet := range tt.want { + want[subnet] = true + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("syncSubnets() got = %v, want %v", got, want) } }) }