From f732a03a6a2711bf78d29367e81f87e8290104f8 Mon Sep 17 00:00:00 2001 From: Yacov Manevich Date: Tue, 24 Sep 2024 21:00:06 +0200 Subject: [PATCH] Address code review comments Signed-off-by: Yacov Manevich --- snow/engine/snowman/straggler_detect.go | 49 ++++++++++---------- snow/engine/snowman/straggler_detect_test.go | 30 ++++++------ snow/networking/handler/health.go | 4 +- 3 files changed, 40 insertions(+), 43 deletions(-) diff --git a/snow/engine/snowman/straggler_detect.go b/snow/engine/snowman/straggler_detect.go index 736aa9f4100b..00653c7ab467 100644 --- a/snow/engine/snowman/straggler_detect.go +++ b/snow/engine/snowman/straggler_detect.go @@ -18,8 +18,9 @@ import ( ) const ( - minStragglerCheckInterval = time.Second - stakeThresholdForStragglerSuspicion = 0.75 + minStragglerCheckInterval = time.Second + stakeThresholdForStragglerSuspicion = 0.75 + minimumStakeThresholdRequiredForNetworkInfo = 0.8 ) type stragglerDetectorConfig struct { @@ -32,16 +33,16 @@ type stragglerDetectorConfig struct { // log logs events log logging.Logger - // minConfirmationThreshold is the minimum percentage that below it, we do not check if we are stragglers. + // minConfirmationThreshold is the minimum stake percentage that below it, we do not check if we are stragglers. minConfirmationThreshold float64 - // connectedPercent returns the percent of connected nodes. + // connectedPercent returns the stake percentage of connected nodes. connectedPercent func() float64 // connectedValidators returns a set of tuples of NodeID and corresponding weight. connectedValidators func() set.Set[ids.NodeWeight] - // lastAcceptedByNodeID returns the last reported height a node has reported, or false if it is unknown. + // lastAcceptedByNodeID returns the last accepted height a node has reported, or false if it is unknown. lastAcceptedByNodeID func(id ids.NodeID) (ids.ID, bool) // processing returns whether this block ID is known and its descendants have not yet been accepted by consensus. @@ -58,7 +59,7 @@ type stragglerDetectorConfig struct { // or false if it fails from some reason. getSnapshot func() (snapshot, bool) - // haveWeFailedCatchingUp returns whether we have not replicated enough blocks of the given snapshot + // haveWeFailedCatchingUp returns whether we have not replicated enough blocks of the given snapshot haveWeFailedCatchingUp func(snapshot) bool } @@ -174,27 +175,27 @@ func (sd *stragglerDetector) failedCatchingUp(s snapshot) bool { } func (sd *stragglerDetector) getNetworkSnapshot() (snapshot, bool) { - nodeWeight2lastAccepted, totalValidatorWeight, _ := sd.getNetworkInfo() - if len(nodeWeight2lastAccepted) == 0 { + nodeWeightToLastAccepted, totalValidatorWeight := sd.getNetworkInfo() + if len(nodeWeightToLastAccepted) == 0 { return snapshot{}, false } ourLastAcceptedBlock := sd.lastAccepted() - prevLastAcceptedCount := len(nodeWeight2lastAccepted) - for k, v := range nodeWeight2lastAccepted { + prevLastAcceptedCount := len(nodeWeightToLastAccepted) + for k, v := range nodeWeightToLastAccepted { if ourLastAcceptedBlock.Compare(v) == 0 { - delete(nodeWeight2lastAccepted, k) + delete(nodeWeightToLastAccepted, k) } } - newLastAcceptedCount := len(nodeWeight2lastAccepted) + newLastAcceptedCount := len(nodeWeightToLastAccepted) sd.log.Trace("Excluding nodes with our own height", zap.Int("prev", prevLastAcceptedCount), zap.Int("new", newLastAcceptedCount)) // Ensure we have collected last accepted blocks that are not our own last accepted block // for at least 80% stake of the total weight we are connected to. - totalWeightWeKnowItsLastAcceptedBlock, err := nodeWeight2lastAccepted.totalWeight() + totalWeightWeKnowItsLastAcceptedBlock, err := nodeWeightToLastAccepted.totalWeight() if err != nil { sd.log.Error("Failed computing total weight", zap.Error(err)) return snapshot{}, false @@ -209,7 +210,7 @@ func (sd *stragglerDetector) getNetworkSnapshot() (snapshot, bool) { } snap := snapshot{ - nodeWeights2Blocks: nodeWeight2lastAccepted, + nodeWeights2Blocks: nodeWeightToLastAccepted, totalValidatorWeight: totalValidatorWeight, } @@ -220,14 +221,14 @@ func (sd *stragglerDetector) getNetworkSnapshot() (snapshot, bool) { return snapshot{}, false } -func (sd *stragglerDetector) getNetworkInfo() (nodeWeights2Blocks, uint64, uint64) { +func (sd *stragglerDetector) getNetworkInfo() (nodeWeights2Blocks, uint64) { ratio := sd.connectedPercent() if ratio < sd.minConfirmationThreshold { // We don't know for sure whether we're behind or not. // Even if we're behind, it's pointless to act before we have established - // connectivity with enough validators. + // connectivity with enough validators. sd.log.Verbo("not enough connected stake to determine network info", zap.Float64("ratio", ratio)) - return nil, 0, 0 + return nil, 0 } validators := nodeWeights(sd.connectedValidators().List()) @@ -245,29 +246,29 @@ func (sd *stragglerDetector) getNetworkInfo() (nodeWeights2Blocks, uint64, uint6 totalValidatorWeight, err := validators.totalWeight() if err != nil { sd.log.Error("Failed computing total weight", zap.Error(err)) - return nil, 0, 0 + return nil, 0 } totalWeightWeKnowItsLastAcceptedBlock, err := nodeWeight2lastAccepted.totalWeight() if err != nil { sd.log.Error("Failed computing total weight", zap.Error(err)) - return nil, 0, 0 + return nil, 0 } if totalValidatorWeight == 0 { sd.log.Trace("Connected to zero weight") - return nil, 0, 0 + return nil, 0 } ratio = float64(totalWeightWeKnowItsLastAcceptedBlock) / float64(totalValidatorWeight) - // Ensure we have collected last accepted blocks for at least 80% stake of the total weight we are connected to. - if ratio < 0.8 { + // Ensure we have collected last accepted blocks for at least 80% (or so) stake of the total weight we are connected to. + if ratio < minimumStakeThresholdRequiredForNetworkInfo { sd.log.Trace("Not collected enough information about last accepted blocks for the validators we are connected to", zap.Float64("ratio", ratio)) - return nil, 0, 0 + return nil, 0 } - return nodeWeight2lastAccepted, totalValidatorWeight, totalWeightWeKnowItsLastAcceptedBlock + return nodeWeight2lastAccepted, totalValidatorWeight } type snapshot struct { diff --git a/snow/engine/snowman/straggler_detect_test.go b/snow/engine/snowman/straggler_detect_test.go index 5d68711b25ff..516349b63f5b 100644 --- a/snow/engine/snowman/straggler_detect_test.go +++ b/snow/engine/snowman/straggler_detect_test.go @@ -51,11 +51,9 @@ func TestNodeWeights2Blocks(t *testing.T) { } func TestGetNetworkSnapshot(t *testing.T) { - n1, err := ids.NodeIDFromString("NodeID-N5gc5soT3Gpr98NKpqvQQG2SgGrVPL64w") - require.NoError(t, err) + n1 := ids.GenerateTestNodeID() - n2, err := ids.NodeIDFromString("NodeID-NpagUxt6KQiwPch9Sd4osv8kD1TZnkjdk") - require.NoError(t, err) + n2 := ids.GenerateTestNodeID() connectedValidators := func(s []ids.NodeWeight) func() set.Set[ids.NodeWeight] { return func() set.Set[ids.NodeWeight] { @@ -150,11 +148,9 @@ func TestGetNetworkSnapshot(t *testing.T) { } func TestFailedCatchingUp(t *testing.T) { - n1, err := ids.NodeIDFromString("NodeID-N5gc5soT3Gpr98NKpqvQQG2SgGrVPL64w") - require.NoError(t, err) + n1 := ids.GenerateTestNodeID() - n2, err := ids.NodeIDFromString("NodeID-NpagUxt6KQiwPch9Sd4osv8kD1TZnkjdk") - require.NoError(t, err) + n2 := ids.GenerateTestNodeID() for _, testCase := range []struct { description string @@ -300,25 +296,25 @@ func TestCheckIfWeAreStragglingBehind(t *testing.T) { for _, testCase := range []struct { description string timeAdvanced time.Duration - evalExtraAssertions func() + evalExtraAssertions func(t *testing.T) expectedStragglingTime time.Duration snapshotsRead []snapshot haveWeFailedCatchingUpReturns bool }{ { description: "First invocation only sets the time", - evalExtraAssertions: func() {}, + evalExtraAssertions: func(t *testing.T) {}, }, { description: "Should not check yet, as it is not time yet", timeAdvanced: time.Millisecond * 500, - evalExtraAssertions: func() {}, + evalExtraAssertions: func(t *testing.T) {}, }, { description: "Advance time some more, so now we should check", timeAdvanced: time.Millisecond * 501, snapshotsRead: []snapshot{{}}, - evalExtraAssertions: func() { + evalExtraAssertions: func(t *testing.T) { require.Contains(t, buff.String(), "No node snapshot obtained") }, }, @@ -326,7 +322,7 @@ func TestCheckIfWeAreStragglingBehind(t *testing.T) { description: "Advance time some more to the first check where the snapshot isn't empty", timeAdvanced: time.Second * 2, snapshotsRead: []snapshot{nonEmptySnap}, - evalExtraAssertions: func() { + evalExtraAssertions: func(t *testing.T) { require.Empty(t, buff.String()) }, }, @@ -335,7 +331,7 @@ func TestCheckIfWeAreStragglingBehind(t *testing.T) { timeAdvanced: time.Second * 2, expectedStragglingTime: time.Second * 2, haveWeFailedCatchingUpReturns: true, - evalExtraAssertions: func() { + evalExtraAssertions: func(t *testing.T) { require.Empty(t, sd.prevSnapshot) }, }, @@ -346,12 +342,12 @@ func TestCheckIfWeAreStragglingBehind(t *testing.T) { // We carry over the total straggling time from previous testCase to this check, // as we need the next check to nullify it. expectedStragglingTime: time.Second * 2, - evalExtraAssertions: func() {}, + evalExtraAssertions: func(t *testing.T) {}, }, { description: "The fourth check returns we have succeeded in catching up", timeAdvanced: time.Second * 2, - evalExtraAssertions: func() {}, + evalExtraAssertions: func(t *testing.T) {}, }, } { t.Run(testCase.description, func(t *testing.T) { @@ -365,7 +361,7 @@ func TestCheckIfWeAreStragglingBehind(t *testing.T) { haveWeFailedCatchingUpReturns = testCase.haveWeFailedCatchingUpReturns require.Equal(t, testCase.expectedStragglingTime, sd.CheckIfWeAreStragglingBehind()) - testCase.evalExtraAssertions() + testCase.evalExtraAssertions(t) // Cleanup the log buffer, and make sure no snapshots remain for next testCase. buff.Reset() diff --git a/snow/networking/handler/health.go b/snow/networking/handler/health.go index fbbc8113e2d1..4c43e0ead003 100644 --- a/snow/networking/handler/health.go +++ b/snow/networking/handler/health.go @@ -66,10 +66,10 @@ func (h *handler) getDisconnectedValidators() set.Set[ids.NodeID] { connectedVdrs := h.peerTracker.ConnectedValidators() // vdrs - connectedVdrs is equal to the disconnectedVdrs vdrs.Difference(connectedVdrs) - return trimWeights(vdrs) + return withoutWeights(vdrs) } -func trimWeights(weights set.Set[ids.NodeWeight]) set.Set[ids.NodeID] { +func withoutWeights(weights set.Set[ids.NodeWeight]) set.Set[ids.NodeID] { var res set.Set[ids.NodeID] for _, nw := range weights.List() { res.Add(nw.Node)