From 77ff3610890806891aae9f1bc5398e3f7768bcc1 Mon Sep 17 00:00:00 2001 From: Yacov Manevich Date: Wed, 11 Sep 2024 23:20:34 +0200 Subject: [PATCH 1/6] Detect if our node is behind the majority This commit adds a mechanism that detects that our node is behind the majority of the stake. The intent is to later have this mechanism be the trigger for the bootstrapping mechanism. Currently, the bootstrapping mechanism is only active upon node boot, but not at a later point. The mechanism works in the following manner: - It intercepts the snowman engine's Chits message handling, and upon every reception of the Chits message, the mechanism that detects if the node is a straggler (a node with a ledger height behind the rest) may be invoked, if it wasn't invoked too recently. - The mechanism draws statistics from the validators known to it, and computes the latest accepted block for each validator. - The mechanism then proceeds to determine which blocks are pending to be processed (a block pending to be processed was not accepted). - The mechanism then collects a snapshot of all blocks it hasn't accepted yet, and the amount of stake that has accepted this block. - The mechanism then waits for its next invocation, in order to see if it has accepted blocks correlated with enough stake. - If there is too much stake that has accepted blocks by other nodes correlated to it that the node hasn't accepted, then the mechanism announces the node is behind, and returns the time period between the two invocations. - The mechanism sums the total time it has detected the node is behind, until a sampling concludes it is not behind, and then the total time is nullified. Signed-off-by: Yacov Manevich --- chains/manager.go | 6 +- ids/node_weight.go | 9 + snow/engine/common/tracker/peers.go | 26 +- snow/engine/common/tracker/peers_test.go | 25 ++ snow/engine/snowman/engine_decorator.go | 42 +++ snow/engine/snowman/engine_decorator_test.go | 74 ++++ snow/engine/snowman/metrics.go | 6 + snow/engine/snowman/straggler_detect.go | 307 +++++++++++++++ snow/engine/snowman/straggler_detect_test.go | 376 +++++++++++++++++++ snow/networking/handler/health.go | 10 +- 10 files changed, 868 insertions(+), 13 deletions(-) create mode 100644 ids/node_weight.go create mode 100644 snow/engine/snowman/engine_decorator.go create mode 100644 snow/engine/snowman/engine_decorator_test.go create mode 100644 snow/engine/snowman/straggler_detect.go create mode 100644 snow/engine/snowman/straggler_detect_test.go diff --git a/chains/manager.go b/chains/manager.go index 61e40f789ddf..9a267db00fc9 100644 --- a/chains/manager.go +++ b/chains/manager.go @@ -1336,12 +1336,14 @@ func (m *manager) createSnowmanChain( Consensus: consensus, PartialSync: m.PartialSyncPrimaryNetwork && ctx.ChainID == constants.PlatformChainID, } - var engine common.Engine - engine, err = smeng.New(engineConfig) + + sme, err := smeng.New(engineConfig) if err != nil { return nil, fmt.Errorf("error initializing snowman engine: %w", err) } + engine := smeng.NewDecoratedEngine(sme, time.Now, func(_ time.Duration) {}) + if m.TracingEnabled { engine = common.TraceEngine(engine, m.Tracer) } diff --git a/ids/node_weight.go b/ids/node_weight.go new file mode 100644 index 000000000000..21309586ca2a --- /dev/null +++ b/ids/node_weight.go @@ -0,0 +1,9 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package ids + +type NodeWeight struct { + Node NodeID + Weight uint64 +} diff --git a/snow/engine/common/tracker/peers.go b/snow/engine/common/tracker/peers.go index 37bf7b10f026..65dda6f7d1ff 100644 --- a/snow/engine/common/tracker/peers.go +++ b/snow/engine/common/tracker/peers.go @@ -9,7 +9,6 @@ import ( "sync" "github.com/prometheus/client_golang/prometheus" - "golang.org/x/exp/maps" "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/snow/validators" @@ -37,10 +36,10 @@ type Peers interface { SampleValidator() (ids.NodeID, bool) // GetValidators returns the set of all validators // known to this peer manager - GetValidators() set.Set[ids.NodeID] + GetValidators() set.Set[ids.NodeWeight] // ConnectedValidators returns the set of all validators // that are currently connected - ConnectedValidators() set.Set[ids.NodeID] + ConnectedValidators() set.Set[ids.NodeWeight] } type lockedPeers struct { @@ -112,14 +111,14 @@ func (p *lockedPeers) SampleValidator() (ids.NodeID, bool) { return p.peers.SampleValidator() } -func (p *lockedPeers) GetValidators() set.Set[ids.NodeID] { +func (p *lockedPeers) GetValidators() set.Set[ids.NodeWeight] { p.lock.RLock() defer p.lock.RUnlock() return p.peers.GetValidators() } -func (p *lockedPeers) ConnectedValidators() set.Set[ids.NodeID] { +func (p *lockedPeers) ConnectedValidators() set.Set[ids.NodeWeight] { p.lock.RLock() defer p.lock.RUnlock() @@ -272,14 +271,21 @@ func (p *peerData) SampleValidator() (ids.NodeID, bool) { return p.connectedValidators.Peek() } -func (p *peerData) GetValidators() set.Set[ids.NodeID] { - return set.Of(maps.Keys(p.validators)...) +func (p *peerData) GetValidators() set.Set[ids.NodeWeight] { + res := set.NewSet[ids.NodeWeight](len(p.validators)) + for k, v := range p.validators { + res.Add(ids.NodeWeight{Node: k, Weight: v}) + } + return res } -func (p *peerData) ConnectedValidators() set.Set[ids.NodeID] { +func (p *peerData) ConnectedValidators() set.Set[ids.NodeWeight] { // The set is copied to avoid future changes from being reflected in the // returned set. - copied := set.NewSet[ids.NodeID](len(p.connectedValidators)) - copied.Union(p.connectedValidators) + copied := set.NewSet[ids.NodeWeight](len(p.connectedValidators)) + for _, vdrID := range p.connectedValidators.List() { + weight := p.validators[vdrID] + copied.Add(ids.NodeWeight{Node: vdrID, Weight: weight}) + } return copied } diff --git a/snow/engine/common/tracker/peers_test.go b/snow/engine/common/tracker/peers_test.go index 1ed2daf6575b..ac577a399b48 100644 --- a/snow/engine/common/tracker/peers_test.go +++ b/snow/engine/common/tracker/peers_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/require" "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/utils/set" "github.com/ava-labs/avalanchego/version" ) @@ -40,3 +41,27 @@ func TestPeers(t *testing.T) { require.NoError(p.Disconnected(context.Background(), nodeID)) require.Zero(p.ConnectedWeight()) } + +func TestConnectedValidators(t *testing.T) { + require := require.New(t) + + nodeID1 := ids.GenerateTestNodeID() + nodeID2 := ids.GenerateTestNodeID() + + p := NewPeers() + + p.OnValidatorAdded(nodeID1, nil, ids.Empty, 5) + p.OnValidatorAdded(nodeID2, nil, ids.Empty, 6) + + require.NoError(p.Connected(context.Background(), nodeID1, version.CurrentApp)) + require.Equal(uint64(5), p.ConnectedWeight()) + + require.NoError(p.Connected(context.Background(), nodeID2, version.CurrentApp)) + require.Equal(uint64(11), p.ConnectedWeight()) + require.True(set.Of(ids.NodeWeight{Node: nodeID1, Weight: 5}, ids.NodeWeight{Node: nodeID2, Weight: 6}).Equals(p.GetValidators())) + require.True(set.Of(ids.NodeWeight{Node: nodeID1, Weight: 5}, ids.NodeWeight{Node: nodeID2, Weight: 6}).Equals(p.ConnectedValidators())) + + require.NoError(p.Disconnected(context.Background(), nodeID2)) + require.True(set.Of(ids.NodeWeight{Node: nodeID1, Weight: 5}, ids.NodeWeight{Node: nodeID2, Weight: 6}).Equals(p.GetValidators())) + require.True(set.Of(ids.NodeWeight{Node: nodeID1, Weight: 5}).Equals(p.ConnectedValidators())) +} diff --git a/snow/engine/snowman/engine_decorator.go b/snow/engine/snowman/engine_decorator.go new file mode 100644 index 000000000000..e1b2bafdb69e --- /dev/null +++ b/snow/engine/snowman/engine_decorator.go @@ -0,0 +1,42 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package snowman + +import ( + "context" + "time" + + "go.uber.org/zap" + + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/snow/engine/common" +) + +type DecoratedEngine struct { + *Engine + sd *stragglerDetector + f func(time.Duration) +} + +func NewDecoratedEngine(e *Engine, time func() time.Time, f func(time.Duration)) common.Engine { + minConfRatio := float64(e.Params.AlphaConfidence) / float64(e.Params.K) + sd := newStragglerDetector(time, e.Config.Ctx.Log, minConfRatio, e.Consensus.LastAccepted, + e.Config.ConnectedValidators.ConnectedValidators, e.Config.ConnectedValidators.ConnectedPercent, + e.Consensus.Processing, e.acceptedFrontiers.LastAccepted) + return &DecoratedEngine{ + Engine: e, + f: f, + sd: sd, + } +} + +func (de *DecoratedEngine) Chits(ctx context.Context, nodeID ids.NodeID, requestID uint32, preferredID ids.ID, preferredIDAtHeight ids.ID, acceptedID ids.ID) error { + behindDuration := de.sd.CheckIfWeAreStragglingBehind() + if behindDuration > 0 { + de.Engine.Config.Ctx.Log.Info("We are behind the rest of the network", zap.Float64("seconds", behindDuration.Seconds())) + } + de.Engine.metrics.stragglingDuration.Set(float64(behindDuration)) + de.f(behindDuration) + return de.Engine.Chits(ctx, nodeID, requestID, preferredID, preferredIDAtHeight, acceptedID) +} diff --git a/snow/engine/snowman/engine_decorator_test.go b/snow/engine/snowman/engine_decorator_test.go new file mode 100644 index 000000000000..454b0b13eaae --- /dev/null +++ b/snow/engine/snowman/engine_decorator_test.go @@ -0,0 +1,74 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package snowman + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/snow/consensus/snowman" + "github.com/ava-labs/avalanchego/snow/consensus/snowman/snowmantest" +) + +func TestEngineStragglerDetector(t *testing.T) { + require := require.New(t) + + fakeClock := make(chan time.Time, 1) + + conf := DefaultConfig(t) + peerID, _, sender, vm, engine := setup(t, conf) + + parent := snowmantest.BuildChild(snowmantest.Genesis) + require.NoError(conf.Consensus.Add(parent)) + + listenerShouldInvokeWith := []time.Duration{0, 0, time.Second * 2} + + fakeTime := func() time.Time { + select { + case now := <-fakeClock: + return now + default: + require.Fail("should have a time.Time in the channel") + return time.Time{} + } + } + + f := func(duration time.Duration) { + require.Equal(listenerShouldInvokeWith[0], duration) + listenerShouldInvokeWith = listenerShouldInvokeWith[1:] + } + + decoratedEngine := NewDecoratedEngine(engine, fakeTime, f) + + vm.GetBlockF = func(_ context.Context, blkID ids.ID) (snowman.Block, error) { + switch blkID { + case snowmantest.GenesisID: + return snowmantest.Genesis, nil + default: + return nil, errUnknownBlock + } + } + + sender.SendGetF = func(_ context.Context, _ ids.NodeID, _ uint32, _ ids.ID) { + } + vm.ParseBlockF = func(_ context.Context, _ []byte) (snowman.Block, error) { + require.FailNow("should not be called") + return nil, nil + } + + now := time.Now() + fakeClock <- now + require.NoError(decoratedEngine.Chits(context.Background(), peerID, 0, parent.ID(), parent.ID(), parent.ID())) + now = now.Add(time.Second * 2) + fakeClock <- now + require.NoError(decoratedEngine.Chits(context.Background(), peerID, 0, parent.ID(), parent.ID(), parent.ID())) + now = now.Add(time.Second * 2) + fakeClock <- now + require.NoError(decoratedEngine.Chits(context.Background(), peerID, 0, parent.ID(), parent.ID(), parent.ID())) + require.Empty(listenerShouldInvokeWith) +} diff --git a/snow/engine/snowman/metrics.go b/snow/engine/snowman/metrics.go index 922b18200d47..68856ba1054b 100644 --- a/snow/engine/snowman/metrics.go +++ b/snow/engine/snowman/metrics.go @@ -23,6 +23,7 @@ type metrics struct { numBlocked prometheus.Gauge numBlockers prometheus.Gauge numNonVerifieds prometheus.Gauge + stragglingDuration prometheus.Gauge numBuilt prometheus.Counter numBuildsFailed prometheus.Counter numUselessPutBytes prometheus.Counter @@ -41,6 +42,10 @@ type metrics struct { func newMetrics(reg prometheus.Registerer) (*metrics, error) { errs := wrappers.Errs{} m := &metrics{ + stragglingDuration: prometheus.NewGauge(prometheus.GaugeOpts{ + Name: "straggling_duration", + Help: "For how long we have been straggling behind the rest, in nano-seconds.", + }), bootstrapFinished: prometheus.NewGauge(prometheus.GaugeOpts{ Name: "bootstrap_finished", Help: "Whether or not bootstrap process has completed. 1 is success, 0 is fail or ongoing.", @@ -128,6 +133,7 @@ func newMetrics(reg prometheus.Registerer) (*metrics, error) { m.issued.WithLabelValues(unknownSource) errs.Add( + reg.Register(m.stragglingDuration), reg.Register(m.bootstrapFinished), reg.Register(m.numRequests), reg.Register(m.numBlocked), diff --git a/snow/engine/snowman/straggler_detect.go b/snow/engine/snowman/straggler_detect.go new file mode 100644 index 000000000000..736aa9f4100b --- /dev/null +++ b/snow/engine/snowman/straggler_detect.go @@ -0,0 +1,307 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package snowman + +import ( + "fmt" + "time" + + "go.uber.org/zap" + "golang.org/x/exp/maps" + + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/utils/logging" + "github.com/ava-labs/avalanchego/utils/set" + + safemath "github.com/ava-labs/avalanchego/utils/math" +) + +const ( + minStragglerCheckInterval = time.Second + stakeThresholdForStragglerSuspicion = 0.75 +) + +type stragglerDetectorConfig struct { + // getTime returns the current time + getTime func() time.Time + + // minStragglerCheckInterval determines how frequently we are allowed to check if we are stragglers. + minStragglerCheckInterval time.Duration + + // log logs events + log logging.Logger + + // minConfirmationThreshold is the minimum percentage that below it, we do not check if we are stragglers. + minConfirmationThreshold float64 + + // connectedPercent returns the percent of connected nodes. + connectedPercent func() float64 + + // connectedValidators returns a set of tuples of NodeID and corresponding weight. + connectedValidators func() set.Set[ids.NodeWeight] + + // lastAcceptedByNodeID returns the last reported height a node has reported, or false if it is unknown. + lastAcceptedByNodeID func(id ids.NodeID) (ids.ID, bool) + + // processing returns whether this block ID is known and its descendants have not yet been accepted by consensus. + // This means that when the last accepted block is given as input, true is returned, as by definition + // its descendants have not been accepted by consensus, but this block is known. + // For any block ID belonging to an ancestor of the last accepted block, false is returned, + // as the last accepted block has been accepted by consensus. + processing func(id ids.ID) bool + + // lastAccepted returns the last accepted block of this node. + lastAccepted func() ids.ID + + // getSnapshot returns a snapshot of the network's nodes and their last accepted blocks, + // or false if it fails from some reason. + getSnapshot func() (snapshot, bool) + + // haveWeFailedCatchingUp returns whether we have not replicated enough blocks of the given snapshot + haveWeFailedCatchingUp func(snapshot) bool +} + +type stragglerDetector struct { + stragglerDetectorConfig + + // continuousStragglingPeriod defines the time we have been straggling continuously. + continuousStragglingPeriod time.Duration + + // previousStragglerCheckTime is the last time we checked whether + // our block height is behind the rest of the network + previousStragglerCheckTime time.Time + + // prevSnapshot is the snapshot from a past iteration. + prevSnapshot snapshot +} + +func newStragglerDetector( + getTime func() time.Time, + log logging.Logger, + minConfirmationThreshold float64, + lastAccepted func() (ids.ID, uint64), + connectedValidators func() set.Set[ids.NodeWeight], + connectedPercent func() float64, + processing func(id ids.ID) bool, + lastAcceptedByNodeID func(ids.NodeID) (ids.ID, bool), +) *stragglerDetector { + sd := &stragglerDetector{ + stragglerDetectorConfig: stragglerDetectorConfig{ + lastAccepted: dropHeight(lastAccepted), + processing: processing, + minStragglerCheckInterval: minStragglerCheckInterval, + log: log, + connectedValidators: connectedValidators, + connectedPercent: connectedPercent, + minConfirmationThreshold: minConfirmationThreshold, + lastAcceptedByNodeID: lastAcceptedByNodeID, + getTime: getTime, + }, + } + + sd.getSnapshot = sd.getNetworkSnapshot + sd.haveWeFailedCatchingUp = sd.failedCatchingUp + + return sd +} + +// CheckIfWeAreStragglingBehind returns for how long our ledger is behind the rest +// of the nodes in the network. If we are not behind, zero is returned. +func (sd *stragglerDetector) CheckIfWeAreStragglingBehind() time.Duration { + now := sd.getTime() + if sd.previousStragglerCheckTime.IsZero() { + sd.previousStragglerCheckTime = now + return 0 + } + + // Don't check too often, only once in every minStragglerCheckInterval + if sd.previousStragglerCheckTime.Add(sd.minStragglerCheckInterval).After(now) { + return 0 + } + + defer func() { + sd.previousStragglerCheckTime = now + }() + + if sd.prevSnapshot.isEmpty() { + snapshot, ok := sd.getSnapshot() + if !ok { + sd.log.Trace("No node snapshot obtained") + sd.continuousStragglingPeriod = 0 + } + sd.prevSnapshot = snapshot + } else { + if sd.haveWeFailedCatchingUp(sd.prevSnapshot) { + timeSinceLastCheck := now.Sub(sd.previousStragglerCheckTime) + sd.continuousStragglingPeriod += timeSinceLastCheck + } else { + sd.continuousStragglingPeriod = 0 + } + sd.prevSnapshot = snapshot{} + } + + return sd.continuousStragglingPeriod +} + +func (sd *stragglerDetector) failedCatchingUp(s snapshot) bool { + totalValidatorWeight, nodeWeights2Blocks := s.totalValidatorWeight, s.nodeWeights2Blocks + + var processingWeight uint64 + for nw, lastAccepted := range nodeWeights2Blocks { + if sd.processing(lastAccepted) { + newProcessingWeight, err := safemath.Add(processingWeight, nw.Weight) + if err != nil { + sd.log.Error("Cumulative weight overflow", zap.Uint64("cumulative", processingWeight), zap.Uint64("added", nw.Weight)) + return false + } + processingWeight = newProcessingWeight + } + } + + sd.log.Trace("Counted total weight that accepted blocks we're still processing", zap.Uint64("weight", processingWeight)) + + ratio := float64(processingWeight) / float64(totalValidatorWeight) + + if ratio > stakeThresholdForStragglerSuspicion { + sd.log.Trace("We are straggling behind", zap.Float64("ratio", ratio)) + return true + } + + sd.log.Trace("Nodes ahead of us:", zap.Float64("ratio", ratio)) + + return false +} + +func (sd *stragglerDetector) getNetworkSnapshot() (snapshot, bool) { + nodeWeight2lastAccepted, totalValidatorWeight, _ := sd.getNetworkInfo() + if len(nodeWeight2lastAccepted) == 0 { + return snapshot{}, false + } + + ourLastAcceptedBlock := sd.lastAccepted() + + prevLastAcceptedCount := len(nodeWeight2lastAccepted) + for k, v := range nodeWeight2lastAccepted { + if ourLastAcceptedBlock.Compare(v) == 0 { + delete(nodeWeight2lastAccepted, k) + } + } + newLastAcceptedCount := len(nodeWeight2lastAccepted) + + sd.log.Trace("Excluding nodes with our own height", zap.Int("prev", prevLastAcceptedCount), zap.Int("new", newLastAcceptedCount)) + + // Ensure we have collected last accepted blocks that are not our own last accepted block + // for at least 80% stake of the total weight we are connected to. + + totalWeightWeKnowItsLastAcceptedBlock, err := nodeWeight2lastAccepted.totalWeight() + if err != nil { + sd.log.Error("Failed computing total weight", zap.Error(err)) + return snapshot{}, false + } + + ratio := float64(totalWeightWeKnowItsLastAcceptedBlock) / float64(totalValidatorWeight) + + if ratio < 0.8 { + sd.log.Trace("Most stake we're connected to has the same height as we do", + zap.Float64("ratio", ratio)) + return snapshot{}, false + } + + snap := snapshot{ + nodeWeights2Blocks: nodeWeight2lastAccepted, + totalValidatorWeight: totalValidatorWeight, + } + + if sd.haveWeFailedCatchingUp(snap) { + return snap, true + } + + return snapshot{}, false +} + +func (sd *stragglerDetector) getNetworkInfo() (nodeWeights2Blocks, uint64, uint64) { + ratio := sd.connectedPercent() + if ratio < sd.minConfirmationThreshold { + // We don't know for sure whether we're behind or not. + // Even if we're behind, it's pointless to act before we have established + // connectivity with enough validators. + sd.log.Verbo("not enough connected stake to determine network info", zap.Float64("ratio", ratio)) + return nil, 0, 0 + } + + validators := nodeWeights(sd.connectedValidators().List()) + + nodeWeight2lastAccepted := make(nodeWeights2Blocks, len(validators)) + + for _, vdr := range validators { + lastAccepted, ok := sd.lastAcceptedByNodeID(vdr.Node) + if !ok { + continue + } + nodeWeight2lastAccepted[vdr] = lastAccepted + } + + totalValidatorWeight, err := validators.totalWeight() + if err != nil { + sd.log.Error("Failed computing total weight", zap.Error(err)) + return nil, 0, 0 + } + + totalWeightWeKnowItsLastAcceptedBlock, err := nodeWeight2lastAccepted.totalWeight() + if err != nil { + sd.log.Error("Failed computing total weight", zap.Error(err)) + return nil, 0, 0 + } + + if totalValidatorWeight == 0 { + sd.log.Trace("Connected to zero weight") + return nil, 0, 0 + } + + ratio = float64(totalWeightWeKnowItsLastAcceptedBlock) / float64(totalValidatorWeight) + + // Ensure we have collected last accepted blocks for at least 80% stake of the total weight we are connected to. + if ratio < 0.8 { + sd.log.Trace("Not collected enough information about last accepted blocks for the validators we are connected to", + zap.Float64("ratio", ratio)) + return nil, 0, 0 + } + return nodeWeight2lastAccepted, totalValidatorWeight, totalWeightWeKnowItsLastAcceptedBlock +} + +type snapshot struct { + totalValidatorWeight uint64 + nodeWeights2Blocks nodeWeights2Blocks +} + +func (s snapshot) isEmpty() bool { + return s.totalValidatorWeight == 0 || len(s.nodeWeights2Blocks) == 0 +} + +type nodeWeights2Blocks map[ids.NodeWeight]ids.ID + +func (nw2b nodeWeights2Blocks) totalWeight() (uint64, error) { + return nodeWeights(maps.Keys(nw2b)).totalWeight() +} + +func dropHeight(f func() (ids.ID, uint64)) func() ids.ID { + return func() ids.ID { + id, _ := f() + return id + } +} + +type nodeWeights []ids.NodeWeight + +func (nws nodeWeights) totalWeight() (uint64, error) { + var weight uint64 + for _, nw := range nws { + newWeight, err := safemath.Add(weight, nw.Weight) + if err != nil { + return 0, fmt.Errorf("cumulative weight: %d, tried to add %d: %w", weight, nw.Weight, err) + } + weight = newWeight + } + return weight, nil +} diff --git a/snow/engine/snowman/straggler_detect_test.go b/snow/engine/snowman/straggler_detect_test.go new file mode 100644 index 000000000000..5d68711b25ff --- /dev/null +++ b/snow/engine/snowman/straggler_detect_test.go @@ -0,0 +1,376 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package snowman + +import ( + "math" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/utils/logging" + "github.com/ava-labs/avalanchego/utils/set" + + safemath "github.com/ava-labs/avalanchego/utils/math" +) + +func TestNodeWeights(t *testing.T) { + nws := nodeWeights{ + {Weight: 100}, + {Weight: 50}, + } + + total, err := nws.totalWeight() + require.NoError(t, err) + require.Equal(t, uint64(150), total) +} + +func TestNodeWeightsOverflow(t *testing.T) { + nws := nodeWeights{ + {Weight: math.MaxUint64 - 100}, + {Weight: 110}, + } + + total, err := nws.totalWeight() + require.ErrorIs(t, err, safemath.ErrOverflow) + require.Equal(t, uint64(0), total) +} + +func TestNodeWeights2Blocks(t *testing.T) { + nw2b := nodeWeights2Blocks{ + ids.NodeWeight{Weight: 5}: ids.Empty, + ids.NodeWeight{Weight: 10}: ids.Empty, + } + + total, err := nw2b.totalWeight() + require.NoError(t, err) + require.Equal(t, uint64(15), total) +} + +func TestGetNetworkSnapshot(t *testing.T) { + n1, err := ids.NodeIDFromString("NodeID-N5gc5soT3Gpr98NKpqvQQG2SgGrVPL64w") + require.NoError(t, err) + + n2, err := ids.NodeIDFromString("NodeID-NpagUxt6KQiwPch9Sd4osv8kD1TZnkjdk") + require.NoError(t, err) + + connectedValidators := func(s []ids.NodeWeight) func() set.Set[ids.NodeWeight] { + return func() set.Set[ids.NodeWeight] { + var set set.Set[ids.NodeWeight] + for _, nw := range s { + set.Add(nw) + } + return set + } + } + + for _, testCase := range []struct { + description string + lastAccepted ids.ID + lastAcceptedFromNodes map[ids.NodeID]ids.ID + processing map[ids.ID]struct{} + connectedValidators func() set.Set[ids.NodeWeight] + connectedPercent float64 + expectedSnapshot snapshot + expectedOK bool + expectedLogged string + }{ + { + description: "not enough connected validators", + connectedValidators: connectedValidators([]ids.NodeWeight{}), + expectedLogged: "not enough connected stake to determine network info", + }, + { + description: "connected to zero weight", + connectedPercent: 1.0, + connectedValidators: connectedValidators([]ids.NodeWeight{}), + expectedLogged: "Connected to zero weight", + }, + { + description: "not enough info", + connectedPercent: 1.0, + connectedValidators: connectedValidators([]ids.NodeWeight{{Weight: 1, Node: n1}, {Weight: 999999, Node: n2}}), + lastAcceptedFromNodes: map[ids.NodeID]ids.ID{ + n1: {0x1}, + }, + expectedLogged: "Not collected enough information about last accepted blocks", + }, + { + description: "we're in sync", + connectedPercent: 1.0, + connectedValidators: connectedValidators([]ids.NodeWeight{{Weight: 999999, Node: n1}}), + lastAcceptedFromNodes: map[ids.NodeID]ids.ID{ + n1: {0x1}, + }, + lastAccepted: ids.ID{0x1}, + expectedLogged: "Most stake we're connected to has the same height as we do", + }, + { + description: "we're behind", + connectedPercent: 1.0, + connectedValidators: connectedValidators([]ids.NodeWeight{{Weight: 999999, Node: n1}}), + lastAcceptedFromNodes: map[ids.NodeID]ids.ID{ + n1: {0x1}, + }, + processing: map[ids.ID]struct{}{{0x1}: {}}, + lastAccepted: ids.ID{0x0}, + expectedSnapshot: snapshot{totalValidatorWeight: 999999, nodeWeights2Blocks: nodeWeights2Blocks{ + ids.NodeWeight{Node: n1, Weight: 999999}: {0x1}, + }}, + expectedOK: true, + }, + } { + t.Run(testCase.description, func(t *testing.T) { + var buff logBuffer + log := logging.NewLogger("", logging.NewWrappedCore(logging.Verbo, &buff, logging.Plain.ConsoleEncoder())) + + sd := newStragglerDetector(nil, log, 0.75, + func() (ids.ID, uint64) { + return testCase.lastAccepted, 0 + }, + testCase.connectedValidators, func() float64 { return testCase.connectedPercent }, + func(id ids.ID) bool { + _, ok := testCase.processing[id] + return ok + }, + func(vdr ids.NodeID) (ids.ID, bool) { + id, ok := testCase.lastAcceptedFromNodes[vdr] + return id, ok + }) + + snapshot, ok := sd.getNetworkSnapshot() + require.Equal(t, testCase.expectedSnapshot, snapshot) + require.Equal(t, testCase.expectedOK, ok) + require.Contains(t, buff.String(), testCase.expectedLogged) + }) + } +} + +func TestFailedCatchingUp(t *testing.T) { + n1, err := ids.NodeIDFromString("NodeID-N5gc5soT3Gpr98NKpqvQQG2SgGrVPL64w") + require.NoError(t, err) + + n2, err := ids.NodeIDFromString("NodeID-NpagUxt6KQiwPch9Sd4osv8kD1TZnkjdk") + require.NoError(t, err) + + for _, testCase := range []struct { + description string + lastAccepted ids.ID + lastAcceptedFromNodes map[ids.NodeID]ids.ID + processing map[ids.ID]struct{} + connectedValidators []ids.NodeWeight + connectedPercent float64 + input snapshot + expected bool + expectedLogged string + }{ + { + description: "stake overflow", + input: snapshot{ + nodeWeights2Blocks: nodeWeights2Blocks{ + ids.NodeWeight{Node: n1, Weight: math.MaxUint64 - 10}: ids.ID{0x1}, + ids.NodeWeight{Node: n2, Weight: 11}: ids.ID{0x2}, + }, + }, + processing: map[ids.ID]struct{}{ + {0x1}: {}, + {0x2}: {}, + }, + expectedLogged: "Cumulative weight overflow", + }, + { + description: "Straggling behind stake minority", + input: snapshot{ + totalValidatorWeight: 100, nodeWeights2Blocks: nodeWeights2Blocks{ + ids.NodeWeight{Node: n1, Weight: 25}: ids.ID{0x1}, + ids.NodeWeight{Node: n2, Weight: 50}: ids.ID{0x2}, + }, + }, + processing: map[ids.ID]struct{}{ + {0x1}: {}, + {0x2}: {}, + }, + expectedLogged: "Nodes ahead of us", + }, + { + description: "Straggling behind stake majority", + input: snapshot{ + totalValidatorWeight: 100, nodeWeights2Blocks: nodeWeights2Blocks{ + ids.NodeWeight{Node: n1, Weight: 26}: ids.ID{0x1}, + ids.NodeWeight{Node: n2, Weight: 50}: ids.ID{0x2}, + }, + }, + processing: map[ids.ID]struct{}{ + {0x1}: {}, + {0x2}: {}, + }, + expectedLogged: "We are straggling behind", + expected: true, + }, + { + description: "In sync with the majority", + input: snapshot{ + totalValidatorWeight: 100, nodeWeights2Blocks: nodeWeights2Blocks{ + ids.NodeWeight{Node: n1, Weight: 75}: ids.ID{0x1}, + ids.NodeWeight{Node: n2, Weight: 25}: ids.ID{0x2}, + }, + }, + processing: map[ids.ID]struct{}{ + {0x2}: {}, + }, + expectedLogged: "Nodes ahead of us", + }, + } { + t.Run(testCase.description, func(t *testing.T) { + var buff logBuffer + log := logging.NewLogger("", logging.NewWrappedCore(logging.Verbo, &buff, logging.Plain.ConsoleEncoder())) + + sd := newStragglerDetector(nil, log, 0.75, + func() (ids.ID, uint64) { + return testCase.lastAccepted, 0 + }, + func() set.Set[ids.NodeWeight] { + var set set.Set[ids.NodeWeight] + for _, nw := range testCase.connectedValidators { + set.Add(nw) + } + return set + }, func() float64 { return testCase.connectedPercent }, + func(id ids.ID) bool { + _, ok := testCase.processing[id] + return ok + }, + func(vdr ids.NodeID) (ids.ID, bool) { + id, ok := testCase.lastAcceptedFromNodes[vdr] + return id, ok + }) + + require.Equal(t, testCase.expected, sd.failedCatchingUp(testCase.input)) + require.Contains(t, buff.String(), testCase.expectedLogged) + }) + } +} + +func TestCheckIfWeAreStragglingBehind(t *testing.T) { + fakeClock := make(chan time.Time, 1) + + snapshots := make(chan snapshot, 1) + assertNoSnapshotsRemain := func() { + select { + case <-snapshots: + require.Fail(t, "Should not have any snapshots in standby") + default: + } + } + nonEmptySnap := snapshot{ + totalValidatorWeight: 100, + nodeWeights2Blocks: nodeWeights2Blocks{ + ids.NodeWeight{Weight: 100}: ids.Empty, + }, + } + + var haveWeFailedCatchingUpReturns bool + + var buff logBuffer + log := logging.NewLogger("", logging.NewWrappedCore(logging.Verbo, &buff, logging.Plain.ConsoleEncoder())) + + sd := stragglerDetector{ + stragglerDetectorConfig: stragglerDetectorConfig{ + minStragglerCheckInterval: time.Second, + getTime: func() time.Time { + now := <-fakeClock + return now + }, + log: log, + getSnapshot: func() (snapshot, bool) { + s := <-snapshots + return s, !s.isEmpty() + }, + haveWeFailedCatchingUp: func(_ snapshot) bool { + return haveWeFailedCatchingUpReturns + }, + }, + } + + fakeTime := time.Now() + + for _, testCase := range []struct { + description string + timeAdvanced time.Duration + evalExtraAssertions func() + expectedStragglingTime time.Duration + snapshotsRead []snapshot + haveWeFailedCatchingUpReturns bool + }{ + { + description: "First invocation only sets the time", + evalExtraAssertions: func() {}, + }, + { + description: "Should not check yet, as it is not time yet", + timeAdvanced: time.Millisecond * 500, + evalExtraAssertions: func() {}, + }, + { + description: "Advance time some more, so now we should check", + timeAdvanced: time.Millisecond * 501, + snapshotsRead: []snapshot{{}}, + evalExtraAssertions: func() { + require.Contains(t, buff.String(), "No node snapshot obtained") + }, + }, + { + description: "Advance time some more to the first check where the snapshot isn't empty", + timeAdvanced: time.Second * 2, + snapshotsRead: []snapshot{nonEmptySnap}, + evalExtraAssertions: func() { + require.Empty(t, buff.String()) + }, + }, + { + description: "The next check returns we have failed catching up.", + timeAdvanced: time.Second * 2, + expectedStragglingTime: time.Second * 2, + haveWeFailedCatchingUpReturns: true, + evalExtraAssertions: func() { + require.Empty(t, sd.prevSnapshot) + }, + }, + { + description: "The third snapshot is due to a fresh check", + timeAdvanced: time.Second * 2, + snapshotsRead: []snapshot{nonEmptySnap}, + // We carry over the total straggling time from previous testCase to this check, + // as we need the next check to nullify it. + expectedStragglingTime: time.Second * 2, + evalExtraAssertions: func() {}, + }, + { + description: "The fourth check returns we have succeeded in catching up", + timeAdvanced: time.Second * 2, + evalExtraAssertions: func() {}, + }, + } { + t.Run(testCase.description, func(t *testing.T) { + fakeTime = fakeTime.Add(testCase.timeAdvanced) + fakeClock <- fakeTime + + // Load the snapshot expected to be retrieved in this testCase, if applicable. + if len(testCase.snapshotsRead) > 0 { + snapshots <- testCase.snapshotsRead[0] + } + + haveWeFailedCatchingUpReturns = testCase.haveWeFailedCatchingUpReturns + require.Equal(t, testCase.expectedStragglingTime, sd.CheckIfWeAreStragglingBehind()) + testCase.evalExtraAssertions() + + // Cleanup the log buffer, and make sure no snapshots remain for next testCase. + buff.Reset() + assertNoSnapshotsRemain() + haveWeFailedCatchingUpReturns = false + }) + } +} diff --git a/snow/networking/handler/health.go b/snow/networking/handler/health.go index 0dbcb844fb95..fbbc8113e2d1 100644 --- a/snow/networking/handler/health.go +++ b/snow/networking/handler/health.go @@ -66,5 +66,13 @@ func (h *handler) getDisconnectedValidators() set.Set[ids.NodeID] { connectedVdrs := h.peerTracker.ConnectedValidators() // vdrs - connectedVdrs is equal to the disconnectedVdrs vdrs.Difference(connectedVdrs) - return vdrs + return trimWeights(vdrs) +} + +func trimWeights(weights set.Set[ids.NodeWeight]) set.Set[ids.NodeID] { + var res set.Set[ids.NodeID] + for _, nw := range weights.List() { + res.Add(nw.Node) + } + return res } From dac03824d7f4e918417b7808734c72c5610b453b Mon Sep 17 00:00:00 2001 From: Yacov Manevich Date: Tue, 24 Sep 2024 21:00:06 +0200 Subject: [PATCH 2/6] Address code review comments Signed-off-by: Yacov Manevich --- snow/engine/snowman/straggler_detect.go | 49 ++++++++++---------- snow/engine/snowman/straggler_detect_test.go | 30 ++++++------ snow/networking/handler/health.go | 4 +- 3 files changed, 40 insertions(+), 43 deletions(-) diff --git a/snow/engine/snowman/straggler_detect.go b/snow/engine/snowman/straggler_detect.go index 736aa9f4100b..00653c7ab467 100644 --- a/snow/engine/snowman/straggler_detect.go +++ b/snow/engine/snowman/straggler_detect.go @@ -18,8 +18,9 @@ import ( ) const ( - minStragglerCheckInterval = time.Second - stakeThresholdForStragglerSuspicion = 0.75 + minStragglerCheckInterval = time.Second + stakeThresholdForStragglerSuspicion = 0.75 + minimumStakeThresholdRequiredForNetworkInfo = 0.8 ) type stragglerDetectorConfig struct { @@ -32,16 +33,16 @@ type stragglerDetectorConfig struct { // log logs events log logging.Logger - // minConfirmationThreshold is the minimum percentage that below it, we do not check if we are stragglers. + // minConfirmationThreshold is the minimum stake percentage that below it, we do not check if we are stragglers. minConfirmationThreshold float64 - // connectedPercent returns the percent of connected nodes. + // connectedPercent returns the stake percentage of connected nodes. connectedPercent func() float64 // connectedValidators returns a set of tuples of NodeID and corresponding weight. connectedValidators func() set.Set[ids.NodeWeight] - // lastAcceptedByNodeID returns the last reported height a node has reported, or false if it is unknown. + // lastAcceptedByNodeID returns the last accepted height a node has reported, or false if it is unknown. lastAcceptedByNodeID func(id ids.NodeID) (ids.ID, bool) // processing returns whether this block ID is known and its descendants have not yet been accepted by consensus. @@ -58,7 +59,7 @@ type stragglerDetectorConfig struct { // or false if it fails from some reason. getSnapshot func() (snapshot, bool) - // haveWeFailedCatchingUp returns whether we have not replicated enough blocks of the given snapshot + // haveWeFailedCatchingUp returns whether we have not replicated enough blocks of the given snapshot haveWeFailedCatchingUp func(snapshot) bool } @@ -174,27 +175,27 @@ func (sd *stragglerDetector) failedCatchingUp(s snapshot) bool { } func (sd *stragglerDetector) getNetworkSnapshot() (snapshot, bool) { - nodeWeight2lastAccepted, totalValidatorWeight, _ := sd.getNetworkInfo() - if len(nodeWeight2lastAccepted) == 0 { + nodeWeightToLastAccepted, totalValidatorWeight := sd.getNetworkInfo() + if len(nodeWeightToLastAccepted) == 0 { return snapshot{}, false } ourLastAcceptedBlock := sd.lastAccepted() - prevLastAcceptedCount := len(nodeWeight2lastAccepted) - for k, v := range nodeWeight2lastAccepted { + prevLastAcceptedCount := len(nodeWeightToLastAccepted) + for k, v := range nodeWeightToLastAccepted { if ourLastAcceptedBlock.Compare(v) == 0 { - delete(nodeWeight2lastAccepted, k) + delete(nodeWeightToLastAccepted, k) } } - newLastAcceptedCount := len(nodeWeight2lastAccepted) + newLastAcceptedCount := len(nodeWeightToLastAccepted) sd.log.Trace("Excluding nodes with our own height", zap.Int("prev", prevLastAcceptedCount), zap.Int("new", newLastAcceptedCount)) // Ensure we have collected last accepted blocks that are not our own last accepted block // for at least 80% stake of the total weight we are connected to. - totalWeightWeKnowItsLastAcceptedBlock, err := nodeWeight2lastAccepted.totalWeight() + totalWeightWeKnowItsLastAcceptedBlock, err := nodeWeightToLastAccepted.totalWeight() if err != nil { sd.log.Error("Failed computing total weight", zap.Error(err)) return snapshot{}, false @@ -209,7 +210,7 @@ func (sd *stragglerDetector) getNetworkSnapshot() (snapshot, bool) { } snap := snapshot{ - nodeWeights2Blocks: nodeWeight2lastAccepted, + nodeWeights2Blocks: nodeWeightToLastAccepted, totalValidatorWeight: totalValidatorWeight, } @@ -220,14 +221,14 @@ func (sd *stragglerDetector) getNetworkSnapshot() (snapshot, bool) { return snapshot{}, false } -func (sd *stragglerDetector) getNetworkInfo() (nodeWeights2Blocks, uint64, uint64) { +func (sd *stragglerDetector) getNetworkInfo() (nodeWeights2Blocks, uint64) { ratio := sd.connectedPercent() if ratio < sd.minConfirmationThreshold { // We don't know for sure whether we're behind or not. // Even if we're behind, it's pointless to act before we have established - // connectivity with enough validators. + // connectivity with enough validators. sd.log.Verbo("not enough connected stake to determine network info", zap.Float64("ratio", ratio)) - return nil, 0, 0 + return nil, 0 } validators := nodeWeights(sd.connectedValidators().List()) @@ -245,29 +246,29 @@ func (sd *stragglerDetector) getNetworkInfo() (nodeWeights2Blocks, uint64, uint6 totalValidatorWeight, err := validators.totalWeight() if err != nil { sd.log.Error("Failed computing total weight", zap.Error(err)) - return nil, 0, 0 + return nil, 0 } totalWeightWeKnowItsLastAcceptedBlock, err := nodeWeight2lastAccepted.totalWeight() if err != nil { sd.log.Error("Failed computing total weight", zap.Error(err)) - return nil, 0, 0 + return nil, 0 } if totalValidatorWeight == 0 { sd.log.Trace("Connected to zero weight") - return nil, 0, 0 + return nil, 0 } ratio = float64(totalWeightWeKnowItsLastAcceptedBlock) / float64(totalValidatorWeight) - // Ensure we have collected last accepted blocks for at least 80% stake of the total weight we are connected to. - if ratio < 0.8 { + // Ensure we have collected last accepted blocks for at least 80% (or so) stake of the total weight we are connected to. + if ratio < minimumStakeThresholdRequiredForNetworkInfo { sd.log.Trace("Not collected enough information about last accepted blocks for the validators we are connected to", zap.Float64("ratio", ratio)) - return nil, 0, 0 + return nil, 0 } - return nodeWeight2lastAccepted, totalValidatorWeight, totalWeightWeKnowItsLastAcceptedBlock + return nodeWeight2lastAccepted, totalValidatorWeight } type snapshot struct { diff --git a/snow/engine/snowman/straggler_detect_test.go b/snow/engine/snowman/straggler_detect_test.go index 5d68711b25ff..238fed062036 100644 --- a/snow/engine/snowman/straggler_detect_test.go +++ b/snow/engine/snowman/straggler_detect_test.go @@ -51,11 +51,9 @@ func TestNodeWeights2Blocks(t *testing.T) { } func TestGetNetworkSnapshot(t *testing.T) { - n1, err := ids.NodeIDFromString("NodeID-N5gc5soT3Gpr98NKpqvQQG2SgGrVPL64w") - require.NoError(t, err) + n1 := ids.GenerateTestNodeID() - n2, err := ids.NodeIDFromString("NodeID-NpagUxt6KQiwPch9Sd4osv8kD1TZnkjdk") - require.NoError(t, err) + n2 := ids.GenerateTestNodeID() connectedValidators := func(s []ids.NodeWeight) func() set.Set[ids.NodeWeight] { return func() set.Set[ids.NodeWeight] { @@ -150,11 +148,9 @@ func TestGetNetworkSnapshot(t *testing.T) { } func TestFailedCatchingUp(t *testing.T) { - n1, err := ids.NodeIDFromString("NodeID-N5gc5soT3Gpr98NKpqvQQG2SgGrVPL64w") - require.NoError(t, err) + n1 := ids.GenerateTestNodeID() - n2, err := ids.NodeIDFromString("NodeID-NpagUxt6KQiwPch9Sd4osv8kD1TZnkjdk") - require.NoError(t, err) + n2 := ids.GenerateTestNodeID() for _, testCase := range []struct { description string @@ -300,25 +296,25 @@ func TestCheckIfWeAreStragglingBehind(t *testing.T) { for _, testCase := range []struct { description string timeAdvanced time.Duration - evalExtraAssertions func() + evalExtraAssertions func(t *testing.T) expectedStragglingTime time.Duration snapshotsRead []snapshot haveWeFailedCatchingUpReturns bool }{ { description: "First invocation only sets the time", - evalExtraAssertions: func() {}, + evalExtraAssertions: func(_ *testing.T) {}, }, { description: "Should not check yet, as it is not time yet", timeAdvanced: time.Millisecond * 500, - evalExtraAssertions: func() {}, + evalExtraAssertions: func(_ *testing.T) {}, }, { description: "Advance time some more, so now we should check", timeAdvanced: time.Millisecond * 501, snapshotsRead: []snapshot{{}}, - evalExtraAssertions: func() { + evalExtraAssertions: func(t *testing.T) { require.Contains(t, buff.String(), "No node snapshot obtained") }, }, @@ -326,7 +322,7 @@ func TestCheckIfWeAreStragglingBehind(t *testing.T) { description: "Advance time some more to the first check where the snapshot isn't empty", timeAdvanced: time.Second * 2, snapshotsRead: []snapshot{nonEmptySnap}, - evalExtraAssertions: func() { + evalExtraAssertions: func(t *testing.T) { require.Empty(t, buff.String()) }, }, @@ -335,7 +331,7 @@ func TestCheckIfWeAreStragglingBehind(t *testing.T) { timeAdvanced: time.Second * 2, expectedStragglingTime: time.Second * 2, haveWeFailedCatchingUpReturns: true, - evalExtraAssertions: func() { + evalExtraAssertions: func(t *testing.T) { require.Empty(t, sd.prevSnapshot) }, }, @@ -346,12 +342,12 @@ func TestCheckIfWeAreStragglingBehind(t *testing.T) { // We carry over the total straggling time from previous testCase to this check, // as we need the next check to nullify it. expectedStragglingTime: time.Second * 2, - evalExtraAssertions: func() {}, + evalExtraAssertions: func(_ *testing.T) {}, }, { description: "The fourth check returns we have succeeded in catching up", timeAdvanced: time.Second * 2, - evalExtraAssertions: func() {}, + evalExtraAssertions: func(_ *testing.T) {}, }, } { t.Run(testCase.description, func(t *testing.T) { @@ -365,7 +361,7 @@ func TestCheckIfWeAreStragglingBehind(t *testing.T) { haveWeFailedCatchingUpReturns = testCase.haveWeFailedCatchingUpReturns require.Equal(t, testCase.expectedStragglingTime, sd.CheckIfWeAreStragglingBehind()) - testCase.evalExtraAssertions() + testCase.evalExtraAssertions(t) // Cleanup the log buffer, and make sure no snapshots remain for next testCase. buff.Reset() diff --git a/snow/networking/handler/health.go b/snow/networking/handler/health.go index fbbc8113e2d1..4c43e0ead003 100644 --- a/snow/networking/handler/health.go +++ b/snow/networking/handler/health.go @@ -66,10 +66,10 @@ func (h *handler) getDisconnectedValidators() set.Set[ids.NodeID] { connectedVdrs := h.peerTracker.ConnectedValidators() // vdrs - connectedVdrs is equal to the disconnectedVdrs vdrs.Difference(connectedVdrs) - return trimWeights(vdrs) + return withoutWeights(vdrs) } -func trimWeights(weights set.Set[ids.NodeWeight]) set.Set[ids.NodeID] { +func withoutWeights(weights set.Set[ids.NodeWeight]) set.Set[ids.NodeID] { var res set.Set[ids.NodeID] for _, nw := range weights.List() { res.Add(nw.Node) From 4d014c05a14be54065a9c8bfe3e4efe237f4fde2 Mon Sep 17 00:00:00 2001 From: Yacov Manevich Date: Thu, 26 Sep 2024 17:22:25 +0200 Subject: [PATCH 3/6] Address code review comments II Signed-off-by: Yacov Manevich --- chains/manager.go | 2 +- snow/engine/snowman/engine_decorator.go | 8 +++--- snow/engine/snowman/engine_decorator_test.go | 2 +- snow/engine/snowman/straggler_detect.go | 29 ++++++++++---------- snow/engine/snowman/straggler_detect_test.go | 14 +++++----- 5 files changed, 28 insertions(+), 27 deletions(-) diff --git a/chains/manager.go b/chains/manager.go index 9a267db00fc9..5b3fbc986739 100644 --- a/chains/manager.go +++ b/chains/manager.go @@ -1342,7 +1342,7 @@ func (m *manager) createSnowmanChain( return nil, fmt.Errorf("error initializing snowman engine: %w", err) } - engine := smeng.NewDecoratedEngine(sme, time.Now, func(_ time.Duration) {}) + engine := smeng.NewDecoratedEngineWithStragglerDetector(sme, time.Now, func(_ time.Duration) {}) if m.TracingEnabled { engine = common.TraceEngine(engine, m.Tracer) diff --git a/snow/engine/snowman/engine_decorator.go b/snow/engine/snowman/engine_decorator.go index e1b2bafdb69e..abeb67d07612 100644 --- a/snow/engine/snowman/engine_decorator.go +++ b/snow/engine/snowman/engine_decorator.go @@ -13,25 +13,25 @@ import ( "github.com/ava-labs/avalanchego/snow/engine/common" ) -type DecoratedEngine struct { +type decoratedEngineWithStragglerDetector struct { *Engine sd *stragglerDetector f func(time.Duration) } -func NewDecoratedEngine(e *Engine, time func() time.Time, f func(time.Duration)) common.Engine { +func NewDecoratedEngineWithStragglerDetector(e *Engine, time func() time.Time, f func(time.Duration)) common.Engine { minConfRatio := float64(e.Params.AlphaConfidence) / float64(e.Params.K) sd := newStragglerDetector(time, e.Config.Ctx.Log, minConfRatio, e.Consensus.LastAccepted, e.Config.ConnectedValidators.ConnectedValidators, e.Config.ConnectedValidators.ConnectedPercent, e.Consensus.Processing, e.acceptedFrontiers.LastAccepted) - return &DecoratedEngine{ + return &decoratedEngineWithStragglerDetector{ Engine: e, f: f, sd: sd, } } -func (de *DecoratedEngine) Chits(ctx context.Context, nodeID ids.NodeID, requestID uint32, preferredID ids.ID, preferredIDAtHeight ids.ID, acceptedID ids.ID) error { +func (de *decoratedEngineWithStragglerDetector) Chits(ctx context.Context, nodeID ids.NodeID, requestID uint32, preferredID ids.ID, preferredIDAtHeight ids.ID, acceptedID ids.ID) error { behindDuration := de.sd.CheckIfWeAreStragglingBehind() if behindDuration > 0 { de.Engine.Config.Ctx.Log.Info("We are behind the rest of the network", zap.Float64("seconds", behindDuration.Seconds())) diff --git a/snow/engine/snowman/engine_decorator_test.go b/snow/engine/snowman/engine_decorator_test.go index 454b0b13eaae..367631249eaa 100644 --- a/snow/engine/snowman/engine_decorator_test.go +++ b/snow/engine/snowman/engine_decorator_test.go @@ -43,7 +43,7 @@ func TestEngineStragglerDetector(t *testing.T) { listenerShouldInvokeWith = listenerShouldInvokeWith[1:] } - decoratedEngine := NewDecoratedEngine(engine, fakeTime, f) + decoratedEngine := NewDecoratedEngineWithStragglerDetector(engine, fakeTime, f) vm.GetBlockF = func(_ context.Context, blkID ids.ID) (snowman.Block, error) { switch blkID { diff --git a/snow/engine/snowman/straggler_detect.go b/snow/engine/snowman/straggler_detect.go index 00653c7ab467..6f4a8f026aff 100644 --- a/snow/engine/snowman/straggler_detect.go +++ b/snow/engine/snowman/straggler_detect.go @@ -21,6 +21,7 @@ const ( minStragglerCheckInterval = time.Second stakeThresholdForStragglerSuspicion = 0.75 minimumStakeThresholdRequiredForNetworkInfo = 0.8 + knownStakeThresholdRequiredForAnalysis = 0.8 ) type stragglerDetectorConfig struct { @@ -146,10 +147,10 @@ func (sd *stragglerDetector) CheckIfWeAreStragglingBehind() time.Duration { } func (sd *stragglerDetector) failedCatchingUp(s snapshot) bool { - totalValidatorWeight, nodeWeights2Blocks := s.totalValidatorWeight, s.nodeWeights2Blocks + totalValidatorWeight, nodeWeightsToBlocks := s.totalValidatorWeight, s.nodeWeightsToBlocks var processingWeight uint64 - for nw, lastAccepted := range nodeWeights2Blocks { + for nw, lastAccepted := range nodeWeightsToBlocks { if sd.processing(lastAccepted) { newProcessingWeight, err := safemath.Add(processingWeight, nw.Weight) if err != nil { @@ -203,14 +204,14 @@ func (sd *stragglerDetector) getNetworkSnapshot() (snapshot, bool) { ratio := float64(totalWeightWeKnowItsLastAcceptedBlock) / float64(totalValidatorWeight) - if ratio < 0.8 { + if ratio < knownStakeThresholdRequiredForAnalysis { sd.log.Trace("Most stake we're connected to has the same height as we do", zap.Float64("ratio", ratio)) return snapshot{}, false } snap := snapshot{ - nodeWeights2Blocks: nodeWeightToLastAccepted, + nodeWeightsToBlocks: nodeWeightToLastAccepted, totalValidatorWeight: totalValidatorWeight, } @@ -221,7 +222,7 @@ func (sd *stragglerDetector) getNetworkSnapshot() (snapshot, bool) { return snapshot{}, false } -func (sd *stragglerDetector) getNetworkInfo() (nodeWeights2Blocks, uint64) { +func (sd *stragglerDetector) getNetworkInfo() (nodeWeightsToBlocks, uint64) { ratio := sd.connectedPercent() if ratio < sd.minConfirmationThreshold { // We don't know for sure whether we're behind or not. @@ -233,14 +234,14 @@ func (sd *stragglerDetector) getNetworkInfo() (nodeWeights2Blocks, uint64) { validators := nodeWeights(sd.connectedValidators().List()) - nodeWeight2lastAccepted := make(nodeWeights2Blocks, len(validators)) + nodeWeightTolastAccepted := make(nodeWeightsToBlocks, len(validators)) for _, vdr := range validators { lastAccepted, ok := sd.lastAcceptedByNodeID(vdr.Node) if !ok { continue } - nodeWeight2lastAccepted[vdr] = lastAccepted + nodeWeightTolastAccepted[vdr] = lastAccepted } totalValidatorWeight, err := validators.totalWeight() @@ -249,7 +250,7 @@ func (sd *stragglerDetector) getNetworkInfo() (nodeWeights2Blocks, uint64) { return nil, 0 } - totalWeightWeKnowItsLastAcceptedBlock, err := nodeWeight2lastAccepted.totalWeight() + totalWeightWeKnowItsLastAcceptedBlock, err := nodeWeightTolastAccepted.totalWeight() if err != nil { sd.log.Error("Failed computing total weight", zap.Error(err)) return nil, 0 @@ -268,22 +269,22 @@ func (sd *stragglerDetector) getNetworkInfo() (nodeWeights2Blocks, uint64) { zap.Float64("ratio", ratio)) return nil, 0 } - return nodeWeight2lastAccepted, totalValidatorWeight + return nodeWeightTolastAccepted, totalValidatorWeight } type snapshot struct { totalValidatorWeight uint64 - nodeWeights2Blocks nodeWeights2Blocks + nodeWeightsToBlocks nodeWeightsToBlocks } func (s snapshot) isEmpty() bool { - return s.totalValidatorWeight == 0 || len(s.nodeWeights2Blocks) == 0 + return s.totalValidatorWeight == 0 || len(s.nodeWeightsToBlocks) == 0 } -type nodeWeights2Blocks map[ids.NodeWeight]ids.ID +type nodeWeightsToBlocks map[ids.NodeWeight]ids.ID -func (nw2b nodeWeights2Blocks) totalWeight() (uint64, error) { - return nodeWeights(maps.Keys(nw2b)).totalWeight() +func (nwb nodeWeightsToBlocks) totalWeight() (uint64, error) { + return nodeWeights(maps.Keys(nwb)).totalWeight() } func dropHeight(f func() (ids.ID, uint64)) func() ids.ID { diff --git a/snow/engine/snowman/straggler_detect_test.go b/snow/engine/snowman/straggler_detect_test.go index 238fed062036..0e85af7c0de0 100644 --- a/snow/engine/snowman/straggler_detect_test.go +++ b/snow/engine/snowman/straggler_detect_test.go @@ -40,7 +40,7 @@ func TestNodeWeightsOverflow(t *testing.T) { } func TestNodeWeights2Blocks(t *testing.T) { - nw2b := nodeWeights2Blocks{ + nw2b := nodeWeightsToBlocks{ ids.NodeWeight{Weight: 5}: ids.Empty, ids.NodeWeight{Weight: 10}: ids.Empty, } @@ -115,7 +115,7 @@ func TestGetNetworkSnapshot(t *testing.T) { }, processing: map[ids.ID]struct{}{{0x1}: {}}, lastAccepted: ids.ID{0x0}, - expectedSnapshot: snapshot{totalValidatorWeight: 999999, nodeWeights2Blocks: nodeWeights2Blocks{ + expectedSnapshot: snapshot{totalValidatorWeight: 999999, nodeWeightsToBlocks: nodeWeightsToBlocks{ ids.NodeWeight{Node: n1, Weight: 999999}: {0x1}, }}, expectedOK: true, @@ -166,7 +166,7 @@ func TestFailedCatchingUp(t *testing.T) { { description: "stake overflow", input: snapshot{ - nodeWeights2Blocks: nodeWeights2Blocks{ + nodeWeightsToBlocks: nodeWeightsToBlocks{ ids.NodeWeight{Node: n1, Weight: math.MaxUint64 - 10}: ids.ID{0x1}, ids.NodeWeight{Node: n2, Weight: 11}: ids.ID{0x2}, }, @@ -180,7 +180,7 @@ func TestFailedCatchingUp(t *testing.T) { { description: "Straggling behind stake minority", input: snapshot{ - totalValidatorWeight: 100, nodeWeights2Blocks: nodeWeights2Blocks{ + totalValidatorWeight: 100, nodeWeightsToBlocks: nodeWeightsToBlocks{ ids.NodeWeight{Node: n1, Weight: 25}: ids.ID{0x1}, ids.NodeWeight{Node: n2, Weight: 50}: ids.ID{0x2}, }, @@ -194,7 +194,7 @@ func TestFailedCatchingUp(t *testing.T) { { description: "Straggling behind stake majority", input: snapshot{ - totalValidatorWeight: 100, nodeWeights2Blocks: nodeWeights2Blocks{ + totalValidatorWeight: 100, nodeWeightsToBlocks: nodeWeightsToBlocks{ ids.NodeWeight{Node: n1, Weight: 26}: ids.ID{0x1}, ids.NodeWeight{Node: n2, Weight: 50}: ids.ID{0x2}, }, @@ -209,7 +209,7 @@ func TestFailedCatchingUp(t *testing.T) { { description: "In sync with the majority", input: snapshot{ - totalValidatorWeight: 100, nodeWeights2Blocks: nodeWeights2Blocks{ + totalValidatorWeight: 100, nodeWeightsToBlocks: nodeWeightsToBlocks{ ids.NodeWeight{Node: n1, Weight: 75}: ids.ID{0x1}, ids.NodeWeight{Node: n2, Weight: 25}: ids.ID{0x2}, }, @@ -263,7 +263,7 @@ func TestCheckIfWeAreStragglingBehind(t *testing.T) { } nonEmptySnap := snapshot{ totalValidatorWeight: 100, - nodeWeights2Blocks: nodeWeights2Blocks{ + nodeWeightsToBlocks: nodeWeightsToBlocks{ ids.NodeWeight{Weight: 100}: ids.Empty, }, } From 0228a655d49018d575f3bbcfbc24676ecf392f37 Mon Sep 17 00:00:00 2001 From: Yacov Manevich Date: Thu, 26 Sep 2024 22:00:56 +0200 Subject: [PATCH 4/6] Address code review comments III Signed-off-by: Yacov Manevich --- snow/engine/snowman/straggler_detect.go | 118 +++++++++++-------- snow/engine/snowman/straggler_detect_test.go | 10 ++ 2 files changed, 80 insertions(+), 48 deletions(-) diff --git a/snow/engine/snowman/straggler_detect.go b/snow/engine/snowman/straggler_detect.go index 6f4a8f026aff..699cd5f21027 100644 --- a/snow/engine/snowman/straggler_detect.go +++ b/snow/engine/snowman/straggler_detect.go @@ -175,44 +175,54 @@ func (sd *stragglerDetector) failedCatchingUp(s snapshot) bool { return false } -func (sd *stragglerDetector) getNetworkSnapshot() (snapshot, bool) { - nodeWeightToLastAccepted, totalValidatorWeight := sd.getNetworkInfo() - if len(nodeWeightToLastAccepted) == 0 { - return snapshot{}, false +func (sd *stragglerDetector) validateNetInfo(netInfo netInfo) bool { + if netInfo.connStakePercent < sd.minConfirmationThreshold { + // We don't know for sure whether we're behind or not. + // Even if we're behind, it's pointless to act before we have established + // connectivity with enough validators. + sd.log.Verbo("not enough connected stake to determine network info", zap.Float64("ratio", netInfo.connStakePercent)) + return false } - ourLastAcceptedBlock := sd.lastAccepted() + if netInfo.totalValidatorWeight == 0 { + sd.log.Trace("Connected to zero weight") + return false + } - prevLastAcceptedCount := len(nodeWeightToLastAccepted) - for k, v := range nodeWeightToLastAccepted { - if ourLastAcceptedBlock.Compare(v) == 0 { - delete(nodeWeightToLastAccepted, k) - } + totalKnownLastBlockStakePercent := float64(netInfo.totalWeightWeKnowItsLastAcceptedBlock) / float64(netInfo.totalValidatorWeight) + stakeAheadOfUs := float64(netInfo.totalPendingStake) / float64(netInfo.totalValidatorWeight) + + // Ensure we have collected last accepted blocks for at least 80% (or so) stake of the total weight we are connected to. + if totalKnownLastBlockStakePercent < minimumStakeThresholdRequiredForNetworkInfo { + sd.log.Trace("Not collected enough information about last accepted blocks for the validators we are connected to", + zap.Float64("ratio", totalKnownLastBlockStakePercent)) + return false } - newLastAcceptedCount := len(nodeWeightToLastAccepted) - sd.log.Trace("Excluding nodes with our own height", zap.Int("prev", prevLastAcceptedCount), zap.Int("new", newLastAcceptedCount)) + if stakeAheadOfUs < knownStakeThresholdRequiredForAnalysis { + sd.log.Trace("Most stake we're connected to has the same height as we do", + zap.Float64("ratio", stakeAheadOfUs)) + return false + } - // Ensure we have collected last accepted blocks that are not our own last accepted block - // for at least 80% stake of the total weight we are connected to. + return true +} - totalWeightWeKnowItsLastAcceptedBlock, err := nodeWeightToLastAccepted.totalWeight() +func (sd *stragglerDetector) getNetworkSnapshot() (snapshot, bool) { + ourLastAcceptedBlock := sd.lastAccepted() + + netInfo, err := sd.getNetworkInfo(ourLastAcceptedBlock) if err != nil { - sd.log.Error("Failed computing total weight", zap.Error(err)) return snapshot{}, false } - ratio := float64(totalWeightWeKnowItsLastAcceptedBlock) / float64(totalValidatorWeight) - - if ratio < knownStakeThresholdRequiredForAnalysis { - sd.log.Trace("Most stake we're connected to has the same height as we do", - zap.Float64("ratio", ratio)) + if !sd.validateNetInfo(netInfo) { return snapshot{}, false } snap := snapshot{ - nodeWeightsToBlocks: nodeWeightToLastAccepted, - totalValidatorWeight: totalValidatorWeight, + nodeWeightsToBlocks: netInfo.nodeWeightToLastAccepted, + totalValidatorWeight: netInfo.totalValidatorWeight, } if sd.haveWeFailedCatchingUp(snap) { @@ -222,54 +232,66 @@ func (sd *stragglerDetector) getNetworkSnapshot() (snapshot, bool) { return snapshot{}, false } -func (sd *stragglerDetector) getNetworkInfo() (nodeWeightsToBlocks, uint64) { - ratio := sd.connectedPercent() - if ratio < sd.minConfirmationThreshold { - // We don't know for sure whether we're behind or not. - // Even if we're behind, it's pointless to act before we have established - // connectivity with enough validators. - sd.log.Verbo("not enough connected stake to determine network info", zap.Float64("ratio", ratio)) - return nil, 0 - } +type netInfo struct { + connStakePercent float64 + totalPendingStake uint64 + totalValidatorWeight uint64 + totalWeightWeKnowItsLastAcceptedBlock uint64 + nodeWeightToLastAccepted nodeWeightsToBlocks +} - validators := nodeWeights(sd.connectedValidators().List()) +func (sd *stragglerDetector) getNetworkInfo(ourLastAcceptedBlock ids.ID) (netInfo, error) { + var res netInfo + + res.connStakePercent = sd.connectedPercent() - nodeWeightTolastAccepted := make(nodeWeightsToBlocks, len(validators)) + validators := nodeWeights(sd.connectedValidators().List()) + nodeWeightToLastAccepted := make(nodeWeightsToBlocks, len(validators)) for _, vdr := range validators { lastAccepted, ok := sd.lastAcceptedByNodeID(vdr.Node) if !ok { continue } - nodeWeightTolastAccepted[vdr] = lastAccepted + nodeWeightToLastAccepted[vdr] = lastAccepted } totalValidatorWeight, err := validators.totalWeight() if err != nil { sd.log.Error("Failed computing total weight", zap.Error(err)) - return nil, 0 + return netInfo{}, err } + res.totalValidatorWeight = totalValidatorWeight - totalWeightWeKnowItsLastAcceptedBlock, err := nodeWeightTolastAccepted.totalWeight() + totalWeightWeKnowItsLastAcceptedBlock, err := nodeWeightToLastAccepted.totalWeight() if err != nil { sd.log.Error("Failed computing total weight", zap.Error(err)) - return nil, 0 + return netInfo{}, err } + res.totalWeightWeKnowItsLastAcceptedBlock = totalWeightWeKnowItsLastAcceptedBlock - if totalValidatorWeight == 0 { - sd.log.Trace("Connected to zero weight") - return nil, 0 + prevLastAcceptedCount := len(nodeWeightToLastAccepted) + + // Ensure we have collected last accepted blocks that are not our own last accepted block. + for nodeWeight, lastAccepted := range nodeWeightToLastAccepted { + if ourLastAcceptedBlock.Compare(lastAccepted) == 0 { + delete(nodeWeightToLastAccepted, nodeWeight) + } } - ratio = float64(totalWeightWeKnowItsLastAcceptedBlock) / float64(totalValidatorWeight) + res.nodeWeightToLastAccepted = nodeWeightToLastAccepted - // Ensure we have collected last accepted blocks for at least 80% (or so) stake of the total weight we are connected to. - if ratio < minimumStakeThresholdRequiredForNetworkInfo { - sd.log.Trace("Not collected enough information about last accepted blocks for the validators we are connected to", - zap.Float64("ratio", ratio)) - return nil, 0 + totalPendingStake, err := nodeWeightToLastAccepted.totalWeight() + if err != nil { + sd.log.Error("Failed computing total weight", zap.Error(err)) + return netInfo{}, err } - return nodeWeightTolastAccepted, totalValidatorWeight + + res.totalPendingStake = totalPendingStake + + sd.log.Trace("Excluding nodes with our own height", zap.Int("prev", prevLastAcceptedCount), zap.Uint64("new", totalPendingStake)) + + return res, nil } type snapshot struct { diff --git a/snow/engine/snowman/straggler_detect_test.go b/snow/engine/snowman/straggler_detect_test.go index 0e85af7c0de0..220975a46bc3 100644 --- a/snow/engine/snowman/straggler_detect_test.go +++ b/snow/engine/snowman/straggler_detect_test.go @@ -120,6 +120,16 @@ func TestGetNetworkSnapshot(t *testing.T) { }}, expectedOK: true, }, + { + description: "we're not behind", + connectedPercent: 1.0, + connectedValidators: connectedValidators([]ids.NodeWeight{{Weight: 999999, Node: n1}}), + lastAcceptedFromNodes: map[ids.NodeID]ids.ID{ + n1: {0x1}, + }, + processing: map[ids.ID]struct{}{{0x2}: {}}, + lastAccepted: ids.ID{0x0}, + }, } { t.Run(testCase.description, func(t *testing.T) { var buff logBuffer From eb68c752767edae7fc30a2dd51c5bb37f5777966 Mon Sep 17 00:00:00 2001 From: Yacov Manevich Date: Sat, 28 Sep 2024 02:14:15 +0200 Subject: [PATCH 5/6] Address code review comments IV Signed-off-by: Yacov Manevich --- ids/node_weight.go | 2 +- snow/engine/common/tracker/peers.go | 4 +- snow/engine/common/tracker/peers_test.go | 8 +- snow/engine/snowman/engine_decorator.go | 27 ++- snow/engine/snowman/engine_decorator_test.go | 27 +-- snow/engine/snowman/straggler_detect.go | 238 +++++++++---------- snow/engine/snowman/straggler_detect_test.go | 132 +++++----- snow/networking/handler/health.go | 2 +- 8 files changed, 209 insertions(+), 231 deletions(-) diff --git a/ids/node_weight.go b/ids/node_weight.go index 21309586ca2a..07d252015462 100644 --- a/ids/node_weight.go +++ b/ids/node_weight.go @@ -4,6 +4,6 @@ package ids type NodeWeight struct { - Node NodeID + ID NodeID Weight uint64 } diff --git a/snow/engine/common/tracker/peers.go b/snow/engine/common/tracker/peers.go index 65dda6f7d1ff..0dc7100c2c12 100644 --- a/snow/engine/common/tracker/peers.go +++ b/snow/engine/common/tracker/peers.go @@ -274,7 +274,7 @@ func (p *peerData) SampleValidator() (ids.NodeID, bool) { func (p *peerData) GetValidators() set.Set[ids.NodeWeight] { res := set.NewSet[ids.NodeWeight](len(p.validators)) for k, v := range p.validators { - res.Add(ids.NodeWeight{Node: k, Weight: v}) + res.Add(ids.NodeWeight{ID: k, Weight: v}) } return res } @@ -285,7 +285,7 @@ func (p *peerData) ConnectedValidators() set.Set[ids.NodeWeight] { copied := set.NewSet[ids.NodeWeight](len(p.connectedValidators)) for _, vdrID := range p.connectedValidators.List() { weight := p.validators[vdrID] - copied.Add(ids.NodeWeight{Node: vdrID, Weight: weight}) + copied.Add(ids.NodeWeight{ID: vdrID, Weight: weight}) } return copied } diff --git a/snow/engine/common/tracker/peers_test.go b/snow/engine/common/tracker/peers_test.go index ac577a399b48..24d1ca1dc725 100644 --- a/snow/engine/common/tracker/peers_test.go +++ b/snow/engine/common/tracker/peers_test.go @@ -58,10 +58,10 @@ func TestConnectedValidators(t *testing.T) { require.NoError(p.Connected(context.Background(), nodeID2, version.CurrentApp)) require.Equal(uint64(11), p.ConnectedWeight()) - require.True(set.Of(ids.NodeWeight{Node: nodeID1, Weight: 5}, ids.NodeWeight{Node: nodeID2, Weight: 6}).Equals(p.GetValidators())) - require.True(set.Of(ids.NodeWeight{Node: nodeID1, Weight: 5}, ids.NodeWeight{Node: nodeID2, Weight: 6}).Equals(p.ConnectedValidators())) + require.True(set.Of(ids.NodeWeight{ID: nodeID1, Weight: 5}, ids.NodeWeight{ID: nodeID2, Weight: 6}).Equals(p.GetValidators())) + require.True(set.Of(ids.NodeWeight{ID: nodeID1, Weight: 5}, ids.NodeWeight{ID: nodeID2, Weight: 6}).Equals(p.ConnectedValidators())) require.NoError(p.Disconnected(context.Background(), nodeID2)) - require.True(set.Of(ids.NodeWeight{Node: nodeID1, Weight: 5}, ids.NodeWeight{Node: nodeID2, Weight: 6}).Equals(p.GetValidators())) - require.True(set.Of(ids.NodeWeight{Node: nodeID1, Weight: 5}).Equals(p.ConnectedValidators())) + require.True(set.Of(ids.NodeWeight{ID: nodeID1, Weight: 5}, ids.NodeWeight{ID: nodeID2, Weight: 6}).Equals(p.GetValidators())) + require.True(set.Of(ids.NodeWeight{ID: nodeID1, Weight: 5}).Equals(p.ConnectedValidators())) } diff --git a/snow/engine/snowman/engine_decorator.go b/snow/engine/snowman/engine_decorator.go index abeb67d07612..41102695dc79 100644 --- a/snow/engine/snowman/engine_decorator.go +++ b/snow/engine/snowman/engine_decorator.go @@ -21,9 +21,30 @@ type decoratedEngineWithStragglerDetector struct { func NewDecoratedEngineWithStragglerDetector(e *Engine, time func() time.Time, f func(time.Duration)) common.Engine { minConfRatio := float64(e.Params.AlphaConfidence) / float64(e.Params.K) - sd := newStragglerDetector(time, e.Config.Ctx.Log, minConfRatio, e.Consensus.LastAccepted, - e.Config.ConnectedValidators.ConnectedValidators, e.Config.ConnectedValidators.ConnectedPercent, - e.Consensus.Processing, e.acceptedFrontiers.LastAccepted) + + sa := &snapshotAnalyzer{ + log: e.Config.Ctx.Log, + processing: e.Consensus.Processing, + } + + s := &snapshotter{ + log: e.Config.Ctx.Log, + connectedValidators: e.Config.ConnectedValidators.ConnectedValidators, + minConfirmationThreshold: minConfRatio, + lastAcceptedByNodeID: e.acceptedFrontiers.LastAccepted, + lastAccepted: dropHeight(e.Consensus.LastAccepted), + } + + conf := stragglerDetectorConfig{ + getSnapshot: s.getNetworkSnapshot, + areWeBehindTheRest: sa.areWeBehindTheRest, + minStragglerCheckInterval: minStragglerCheckInterval, + log: e.Config.Ctx.Log, + getTime: time, + } + + sd := newStragglerDetector(conf) + return &decoratedEngineWithStragglerDetector{ Engine: e, f: f, diff --git a/snow/engine/snowman/engine_decorator_test.go b/snow/engine/snowman/engine_decorator_test.go index 367631249eaa..c40664c1e051 100644 --- a/snow/engine/snowman/engine_decorator_test.go +++ b/snow/engine/snowman/engine_decorator_test.go @@ -13,12 +13,13 @@ import ( "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/snow/consensus/snowman" "github.com/ava-labs/avalanchego/snow/consensus/snowman/snowmantest" + "github.com/ava-labs/avalanchego/utils/timer/mockable" ) func TestEngineStragglerDetector(t *testing.T) { require := require.New(t) - fakeClock := make(chan time.Time, 1) + var fakeClock mockable.Clock conf := DefaultConfig(t) peerID, _, sender, vm, engine := setup(t, conf) @@ -26,24 +27,14 @@ func TestEngineStragglerDetector(t *testing.T) { parent := snowmantest.BuildChild(snowmantest.Genesis) require.NoError(conf.Consensus.Add(parent)) - listenerShouldInvokeWith := []time.Duration{0, 0, time.Second * 2} - - fakeTime := func() time.Time { - select { - case now := <-fakeClock: - return now - default: - require.Fail("should have a time.Time in the channel") - return time.Time{} - } - } + listenerShouldInvokeWith := []time.Duration{0, 0, minStragglerCheckInterval * 2} f := func(duration time.Duration) { require.Equal(listenerShouldInvokeWith[0], duration) listenerShouldInvokeWith = listenerShouldInvokeWith[1:] } - decoratedEngine := NewDecoratedEngineWithStragglerDetector(engine, fakeTime, f) + decoratedEngine := NewDecoratedEngineWithStragglerDetector(engine, fakeClock.Time, f) vm.GetBlockF = func(_ context.Context, blkID ids.ID) (snowman.Block, error) { switch blkID { @@ -62,13 +53,13 @@ func TestEngineStragglerDetector(t *testing.T) { } now := time.Now() - fakeClock <- now + fakeClock.Set(now) require.NoError(decoratedEngine.Chits(context.Background(), peerID, 0, parent.ID(), parent.ID(), parent.ID())) - now = now.Add(time.Second * 2) - fakeClock <- now + now = now.Add(minStragglerCheckInterval * 2) + fakeClock.Set(now) require.NoError(decoratedEngine.Chits(context.Background(), peerID, 0, parent.ID(), parent.ID(), parent.ID())) - now = now.Add(time.Second * 2) - fakeClock <- now + now = now.Add(minStragglerCheckInterval * 2) + fakeClock.Set(now) require.NoError(decoratedEngine.Chits(context.Background(), peerID, 0, parent.ID(), parent.ID(), parent.ID())) require.Empty(listenerShouldInvokeWith) } diff --git a/snow/engine/snowman/straggler_detect.go b/snow/engine/snowman/straggler_detect.go index 699cd5f21027..5a8525239368 100644 --- a/snow/engine/snowman/straggler_detect.go +++ b/snow/engine/snowman/straggler_detect.go @@ -18,7 +18,7 @@ import ( ) const ( - minStragglerCheckInterval = time.Second + minStragglerCheckInterval = 10 * time.Second stakeThresholdForStragglerSuspicion = 0.75 minimumStakeThresholdRequiredForNetworkInfo = 0.8 knownStakeThresholdRequiredForAnalysis = 0.8 @@ -34,34 +34,13 @@ type stragglerDetectorConfig struct { // log logs events log logging.Logger - // minConfirmationThreshold is the minimum stake percentage that below it, we do not check if we are stragglers. - minConfirmationThreshold float64 - - // connectedPercent returns the stake percentage of connected nodes. - connectedPercent func() float64 - - // connectedValidators returns a set of tuples of NodeID and corresponding weight. - connectedValidators func() set.Set[ids.NodeWeight] - - // lastAcceptedByNodeID returns the last accepted height a node has reported, or false if it is unknown. - lastAcceptedByNodeID func(id ids.NodeID) (ids.ID, bool) - - // processing returns whether this block ID is known and its descendants have not yet been accepted by consensus. - // This means that when the last accepted block is given as input, true is returned, as by definition - // its descendants have not been accepted by consensus, but this block is known. - // For any block ID belonging to an ancestor of the last accepted block, false is returned, - // as the last accepted block has been accepted by consensus. - processing func(id ids.ID) bool - - // lastAccepted returns the last accepted block of this node. - lastAccepted func() ids.ID - // getSnapshot returns a snapshot of the network's nodes and their last accepted blocks, - // or false if it fails from some reason. + // excluding nodes that have the same last accepted block as we do. + // Returns false if it fails from some reason. getSnapshot func() (snapshot, bool) - // haveWeFailedCatchingUp returns whether we have not replicated enough blocks of the given snapshot - haveWeFailedCatchingUp func(snapshot) bool + // areWeBehindTheRest returns whether we have not replicated enough blocks of the given snapshot + areWeBehindTheRest func(snapshot) bool } type stragglerDetector struct { @@ -78,34 +57,10 @@ type stragglerDetector struct { prevSnapshot snapshot } -func newStragglerDetector( - getTime func() time.Time, - log logging.Logger, - minConfirmationThreshold float64, - lastAccepted func() (ids.ID, uint64), - connectedValidators func() set.Set[ids.NodeWeight], - connectedPercent func() float64, - processing func(id ids.ID) bool, - lastAcceptedByNodeID func(ids.NodeID) (ids.ID, bool), -) *stragglerDetector { - sd := &stragglerDetector{ - stragglerDetectorConfig: stragglerDetectorConfig{ - lastAccepted: dropHeight(lastAccepted), - processing: processing, - minStragglerCheckInterval: minStragglerCheckInterval, - log: log, - connectedValidators: connectedValidators, - connectedPercent: connectedPercent, - minConfirmationThreshold: minConfirmationThreshold, - lastAcceptedByNodeID: lastAcceptedByNodeID, - getTime: getTime, - }, +func newStragglerDetector(config stragglerDetectorConfig) *stragglerDetector { + return &stragglerDetector{ + stragglerDetectorConfig: config, } - - sd.getSnapshot = sd.getNetworkSnapshot - sd.haveWeFailedCatchingUp = sd.failedCatchingUp - - return sd } // CheckIfWeAreStragglingBehind returns for how long our ledger is behind the rest @@ -127,65 +82,92 @@ func (sd *stragglerDetector) CheckIfWeAreStragglingBehind() time.Duration { }() if sd.prevSnapshot.isEmpty() { - snapshot, ok := sd.getSnapshot() - if !ok { - sd.log.Trace("No node snapshot obtained") - sd.continuousStragglingPeriod = 0 - } - sd.prevSnapshot = snapshot + sd.obtainSnapshot() } else { - if sd.haveWeFailedCatchingUp(sd.prevSnapshot) { - timeSinceLastCheck := now.Sub(sd.previousStragglerCheckTime) - sd.continuousStragglingPeriod += timeSinceLastCheck - } else { - sd.continuousStragglingPeriod = 0 - } - sd.prevSnapshot = snapshot{} + sd.evaluateSnapshot(now) } return sd.continuousStragglingPeriod } -func (sd *stragglerDetector) failedCatchingUp(s snapshot) bool { - totalValidatorWeight, nodeWeightsToBlocks := s.totalValidatorWeight, s.nodeWeightsToBlocks - - var processingWeight uint64 - for nw, lastAccepted := range nodeWeightsToBlocks { - if sd.processing(lastAccepted) { - newProcessingWeight, err := safemath.Add(processingWeight, nw.Weight) - if err != nil { - sd.log.Error("Cumulative weight overflow", zap.Uint64("cumulative", processingWeight), zap.Uint64("added", nw.Weight)) - return false - } - processingWeight = newProcessingWeight - } +func (sd *stragglerDetector) obtainSnapshot() { + snap, ok := sd.getSnapshot() + sd.prevSnapshot = snap + if !ok || !sd.areWeBehindTheRest(snap) { + sd.log.Trace("No node snapshot obtained") + sd.continuousStragglingPeriod = 0 + sd.prevSnapshot = snapshot{} + } +} + +func (sd *stragglerDetector) evaluateSnapshot(now time.Time) { + if sd.areWeBehindTheRest(sd.prevSnapshot) { + timeSinceLastCheck := now.Sub(sd.previousStragglerCheckTime) + sd.continuousStragglingPeriod += timeSinceLastCheck + } else { + sd.continuousStragglingPeriod = 0 + } + sd.prevSnapshot = snapshot{} +} + +type snapshotAnalyzer struct { + log logging.Logger + + // processing returns whether this block ID is known and its descendants have not yet been accepted by consensus. + // This means that when the last accepted block is given as input, true is returned, as by definition + // its descendants have not been accepted by consensus, but this block is known. + // For any block ID belonging to an ancestor of the last accepted block, false is returned, + // as the last accepted block has been accepted by consensus. + processing func(id ids.ID) bool +} + +func (sa *snapshotAnalyzer) areWeBehindTheRest(s snapshot) bool { + if s.isEmpty() { + return false } - sd.log.Trace("Counted total weight that accepted blocks we're still processing", zap.Uint64("weight", processingWeight)) + totalValidatorWeight, nodeWeightsToBlocks := s.totalValidatorWeight, s.lastAcceptedBlockID + + processingWeight, err := nodeWeightsToBlocks.filter(sa.processing).totalWeight() + if err != nil { + sa.log.Error("Failed computing total weight", zap.Error(err)) + return false + } + + sa.log.Trace("Counted total weight that accepted blocks we're still processing", zap.Uint64("weight", processingWeight)) ratio := float64(processingWeight) / float64(totalValidatorWeight) if ratio > stakeThresholdForStragglerSuspicion { - sd.log.Trace("We are straggling behind", zap.Float64("ratio", ratio)) + sa.log.Trace("We are straggling behind", zap.Float64("ratio", ratio)) return true } - sd.log.Trace("Nodes ahead of us:", zap.Float64("ratio", ratio)) + sa.log.Trace("Nodes ahead of us:", zap.Float64("ratio", ratio)) return false } -func (sd *stragglerDetector) validateNetInfo(netInfo netInfo) bool { - if netInfo.connStakePercent < sd.minConfirmationThreshold { - // We don't know for sure whether we're behind or not. - // Even if we're behind, it's pointless to act before we have established - // connectivity with enough validators. - sd.log.Verbo("not enough connected stake to determine network info", zap.Float64("ratio", netInfo.connStakePercent)) - return false - } +type snapshotter struct { + // log logs events + log logging.Logger + + // minConfirmationThreshold is the minimum stake percentage that below it, we do not check if we are stragglers. + minConfirmationThreshold float64 + // lastAccepted returns the last accepted block of this node. + lastAccepted func() ids.ID + + // connectedValidators returns a set of tuples of NodeID and corresponding weight. + connectedValidators func() set.Set[ids.NodeWeight] + + // lastAcceptedByNodeID returns the last accepted height a node has reported, or false if it is unknown. + lastAcceptedByNodeID func(id ids.NodeID) (ids.ID, bool) +} + +func (s *snapshotter) validateNetInfo(netInfo netInfo) bool { if netInfo.totalValidatorWeight == 0 { - sd.log.Trace("Connected to zero weight") + s.log.Trace("Connected to zero weight") return false } @@ -194,13 +176,13 @@ func (sd *stragglerDetector) validateNetInfo(netInfo netInfo) bool { // Ensure we have collected last accepted blocks for at least 80% (or so) stake of the total weight we are connected to. if totalKnownLastBlockStakePercent < minimumStakeThresholdRequiredForNetworkInfo { - sd.log.Trace("Not collected enough information about last accepted blocks for the validators we are connected to", + s.log.Trace("Not collected enough information about last accepted blocks for the validators we are connected to", zap.Float64("ratio", totalKnownLastBlockStakePercent)) return false } if stakeAheadOfUs < knownStakeThresholdRequiredForAnalysis { - sd.log.Trace("Most stake we're connected to has the same height as we do", + s.log.Trace("Most stake we're connected to has the same height as we do", zap.Float64("ratio", stakeAheadOfUs)) return false } @@ -208,48 +190,34 @@ func (sd *stragglerDetector) validateNetInfo(netInfo netInfo) bool { return true } -func (sd *stragglerDetector) getNetworkSnapshot() (snapshot, bool) { - ourLastAcceptedBlock := sd.lastAccepted() +func (s *snapshotter) getNetworkSnapshot() (snapshot, bool) { + ourLastAcceptedBlock := s.lastAccepted() - netInfo, err := sd.getNetworkInfo(ourLastAcceptedBlock) + netInfo, err := s.getNetworkInfo(ourLastAcceptedBlock) if err != nil { return snapshot{}, false } - if !sd.validateNetInfo(netInfo) { + if !s.validateNetInfo(netInfo) { return snapshot{}, false } snap := snapshot{ - nodeWeightsToBlocks: netInfo.nodeWeightToLastAccepted, + lastAcceptedBlockID: netInfo.nodeWeightToLastAccepted, totalValidatorWeight: netInfo.totalValidatorWeight, } - if sd.haveWeFailedCatchingUp(snap) { - return snap, true - } - - return snapshot{}, false + return snap, true } -type netInfo struct { - connStakePercent float64 - totalPendingStake uint64 - totalValidatorWeight uint64 - totalWeightWeKnowItsLastAcceptedBlock uint64 - nodeWeightToLastAccepted nodeWeightsToBlocks -} - -func (sd *stragglerDetector) getNetworkInfo(ourLastAcceptedBlock ids.ID) (netInfo, error) { +func (s *snapshotter) getNetworkInfo(ourLastAcceptedBlock ids.ID) (netInfo, error) { var res netInfo - res.connStakePercent = sd.connectedPercent() - - validators := nodeWeights(sd.connectedValidators().List()) + validators := nodeWeights(s.connectedValidators().List()) nodeWeightToLastAccepted := make(nodeWeightsToBlocks, len(validators)) for _, vdr := range validators { - lastAccepted, ok := sd.lastAcceptedByNodeID(vdr.Node) + lastAccepted, ok := s.lastAcceptedByNodeID(vdr.ID) if !ok { continue } @@ -258,14 +226,14 @@ func (sd *stragglerDetector) getNetworkInfo(ourLastAcceptedBlock ids.ID) (netInf totalValidatorWeight, err := validators.totalWeight() if err != nil { - sd.log.Error("Failed computing total weight", zap.Error(err)) + s.log.Error("Failed computing total weight", zap.Error(err)) return netInfo{}, err } res.totalValidatorWeight = totalValidatorWeight totalWeightWeKnowItsLastAcceptedBlock, err := nodeWeightToLastAccepted.totalWeight() if err != nil { - sd.log.Error("Failed computing total weight", zap.Error(err)) + s.log.Error("Failed computing total weight", zap.Error(err)) return netInfo{}, err } res.totalWeightWeKnowItsLastAcceptedBlock = totalWeightWeKnowItsLastAcceptedBlock @@ -273,34 +241,39 @@ func (sd *stragglerDetector) getNetworkInfo(ourLastAcceptedBlock ids.ID) (netInf prevLastAcceptedCount := len(nodeWeightToLastAccepted) // Ensure we have collected last accepted blocks that are not our own last accepted block. - for nodeWeight, lastAccepted := range nodeWeightToLastAccepted { - if ourLastAcceptedBlock.Compare(lastAccepted) == 0 { - delete(nodeWeightToLastAccepted, nodeWeight) - } - } + nodeWeightToLastAccepted = nodeWeightToLastAccepted.filter(func(id ids.ID) bool { + return ourLastAcceptedBlock.Compare(id) != 0 + }) res.nodeWeightToLastAccepted = nodeWeightToLastAccepted totalPendingStake, err := nodeWeightToLastAccepted.totalWeight() if err != nil { - sd.log.Error("Failed computing total weight", zap.Error(err)) + s.log.Error("Failed computing total weight", zap.Error(err)) return netInfo{}, err } res.totalPendingStake = totalPendingStake - sd.log.Trace("Excluding nodes with our own height", zap.Int("prev", prevLastAcceptedCount), zap.Uint64("new", totalPendingStake)) + s.log.Trace("Excluding nodes with our own height", zap.Int("prev", prevLastAcceptedCount), zap.Uint64("new", totalPendingStake)) return res, nil } +type netInfo struct { + totalPendingStake uint64 + totalValidatorWeight uint64 + totalWeightWeKnowItsLastAcceptedBlock uint64 + nodeWeightToLastAccepted nodeWeightsToBlocks +} + type snapshot struct { totalValidatorWeight uint64 - nodeWeightsToBlocks nodeWeightsToBlocks + lastAcceptedBlockID nodeWeightsToBlocks } func (s snapshot) isEmpty() bool { - return s.totalValidatorWeight == 0 || len(s.nodeWeightsToBlocks) == 0 + return s.totalValidatorWeight == 0 || len(s.lastAcceptedBlockID) == 0 } type nodeWeightsToBlocks map[ids.NodeWeight]ids.ID @@ -309,6 +282,7 @@ func (nwb nodeWeightsToBlocks) totalWeight() (uint64, error) { return nodeWeights(maps.Keys(nwb)).totalWeight() } +// dropHeight removes the second return parameter from the function f() and keeps its first return parameter, ids.ID. func dropHeight(f func() (ids.ID, uint64)) func() ids.ID { return func() ids.ID { id, _ := f() @@ -329,3 +303,13 @@ func (nws nodeWeights) totalWeight() (uint64, error) { } return weight, nil } + +func (nwb nodeWeightsToBlocks) filter(f func(ids.ID) bool) nodeWeightsToBlocks { + filtered := make(nodeWeightsToBlocks, len(nwb)) + for nw, id := range nwb { + if f(id) { + filtered[nw] = id + } + } + return filtered +} diff --git a/snow/engine/snowman/straggler_detect_test.go b/snow/engine/snowman/straggler_detect_test.go index 220975a46bc3..63d072ae770d 100644 --- a/snow/engine/snowman/straggler_detect_test.go +++ b/snow/engine/snowman/straggler_detect_test.go @@ -4,6 +4,7 @@ package snowman import ( + "fmt" "math" "testing" "time" @@ -13,6 +14,7 @@ import ( "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/utils/logging" "github.com/ava-labs/avalanchego/utils/set" + "github.com/ava-labs/avalanchego/utils/timer/mockable" safemath "github.com/ava-labs/avalanchego/utils/math" ) @@ -71,26 +73,18 @@ func TestGetNetworkSnapshot(t *testing.T) { lastAcceptedFromNodes map[ids.NodeID]ids.ID processing map[ids.ID]struct{} connectedValidators func() set.Set[ids.NodeWeight] - connectedPercent float64 expectedSnapshot snapshot expectedOK bool expectedLogged string }{ - { - description: "not enough connected validators", - connectedValidators: connectedValidators([]ids.NodeWeight{}), - expectedLogged: "not enough connected stake to determine network info", - }, { description: "connected to zero weight", - connectedPercent: 1.0, connectedValidators: connectedValidators([]ids.NodeWeight{}), expectedLogged: "Connected to zero weight", }, { description: "not enough info", - connectedPercent: 1.0, - connectedValidators: connectedValidators([]ids.NodeWeight{{Weight: 1, Node: n1}, {Weight: 999999, Node: n2}}), + connectedValidators: connectedValidators([]ids.NodeWeight{{Weight: 1, ID: n1}, {Weight: 999999, ID: n2}}), lastAcceptedFromNodes: map[ids.NodeID]ids.ID{ n1: {0x1}, }, @@ -98,8 +92,7 @@ func TestGetNetworkSnapshot(t *testing.T) { }, { description: "we're in sync", - connectedPercent: 1.0, - connectedValidators: connectedValidators([]ids.NodeWeight{{Weight: 999999, Node: n1}}), + connectedValidators: connectedValidators([]ids.NodeWeight{{Weight: 999999, ID: n1}}), lastAcceptedFromNodes: map[ids.NodeID]ids.ID{ n1: {0x1}, }, @@ -108,48 +101,49 @@ func TestGetNetworkSnapshot(t *testing.T) { }, { description: "we're behind", - connectedPercent: 1.0, - connectedValidators: connectedValidators([]ids.NodeWeight{{Weight: 999999, Node: n1}}), + connectedValidators: connectedValidators([]ids.NodeWeight{{Weight: 999999, ID: n1}}), lastAcceptedFromNodes: map[ids.NodeID]ids.ID{ n1: {0x1}, }, processing: map[ids.ID]struct{}{{0x1}: {}}, lastAccepted: ids.ID{0x0}, - expectedSnapshot: snapshot{totalValidatorWeight: 999999, nodeWeightsToBlocks: nodeWeightsToBlocks{ - ids.NodeWeight{Node: n1, Weight: 999999}: {0x1}, + expectedSnapshot: snapshot{totalValidatorWeight: 999999, lastAcceptedBlockID: nodeWeightsToBlocks{ + ids.NodeWeight{ID: n1, Weight: 999999}: {0x1}, }}, expectedOK: true, }, { description: "we're not behind", - connectedPercent: 1.0, - connectedValidators: connectedValidators([]ids.NodeWeight{{Weight: 999999, Node: n1}}), + connectedValidators: connectedValidators([]ids.NodeWeight{{Weight: 999999, ID: n1}}), lastAcceptedFromNodes: map[ids.NodeID]ids.ID{ n1: {0x1}, }, processing: map[ids.ID]struct{}{{0x2}: {}}, lastAccepted: ids.ID{0x0}, + expectedSnapshot: snapshot{totalValidatorWeight: 999999, lastAcceptedBlockID: nodeWeightsToBlocks{ + ids.NodeWeight{ID: n1, Weight: 999999}: {0x1}, + }}, + expectedOK: true, }, } { t.Run(testCase.description, func(t *testing.T) { var buff logBuffer log := logging.NewLogger("", logging.NewWrappedCore(logging.Verbo, &buff, logging.Plain.ConsoleEncoder())) - sd := newStragglerDetector(nil, log, 0.75, - func() (ids.ID, uint64) { - return testCase.lastAccepted, 0 + s := &snapshotter{ + log: log, + connectedValidators: testCase.connectedValidators, + minConfirmationThreshold: 0.75, + lastAccepted: func() ids.ID { + return testCase.lastAccepted }, - testCase.connectedValidators, func() float64 { return testCase.connectedPercent }, - func(id ids.ID) bool { - _, ok := testCase.processing[id] - return ok - }, - func(vdr ids.NodeID) (ids.ID, bool) { + lastAcceptedByNodeID: func(vdr ids.NodeID) (ids.ID, bool) { id, ok := testCase.lastAcceptedFromNodes[vdr] return id, ok - }) + }, + } - snapshot, ok := sd.getNetworkSnapshot() + snapshot, ok := s.getNetworkSnapshot() require.Equal(t, testCase.expectedSnapshot, snapshot) require.Equal(t, testCase.expectedOK, ok) require.Contains(t, buff.String(), testCase.expectedLogged) @@ -168,7 +162,6 @@ func TestFailedCatchingUp(t *testing.T) { lastAcceptedFromNodes map[ids.NodeID]ids.ID processing map[ids.ID]struct{} connectedValidators []ids.NodeWeight - connectedPercent float64 input snapshot expected bool expectedLogged string @@ -176,23 +169,24 @@ func TestFailedCatchingUp(t *testing.T) { { description: "stake overflow", input: snapshot{ - nodeWeightsToBlocks: nodeWeightsToBlocks{ - ids.NodeWeight{Node: n1, Weight: math.MaxUint64 - 10}: ids.ID{0x1}, - ids.NodeWeight{Node: n2, Weight: 11}: ids.ID{0x2}, + totalValidatorWeight: 100, + lastAcceptedBlockID: nodeWeightsToBlocks{ + ids.NodeWeight{ID: n1, Weight: math.MaxUint64 - 10}: ids.ID{0x1}, + ids.NodeWeight{ID: n2, Weight: 11}: ids.ID{0x2}, }, }, processing: map[ids.ID]struct{}{ {0x1}: {}, {0x2}: {}, }, - expectedLogged: "Cumulative weight overflow", + expectedLogged: "Failed computing total weight", }, { description: "Straggling behind stake minority", input: snapshot{ - totalValidatorWeight: 100, nodeWeightsToBlocks: nodeWeightsToBlocks{ - ids.NodeWeight{Node: n1, Weight: 25}: ids.ID{0x1}, - ids.NodeWeight{Node: n2, Weight: 50}: ids.ID{0x2}, + totalValidatorWeight: 100, lastAcceptedBlockID: nodeWeightsToBlocks{ + ids.NodeWeight{ID: n1, Weight: 25}: ids.ID{0x1}, + ids.NodeWeight{ID: n2, Weight: 50}: ids.ID{0x2}, }, }, processing: map[ids.ID]struct{}{ @@ -204,9 +198,9 @@ func TestFailedCatchingUp(t *testing.T) { { description: "Straggling behind stake majority", input: snapshot{ - totalValidatorWeight: 100, nodeWeightsToBlocks: nodeWeightsToBlocks{ - ids.NodeWeight{Node: n1, Weight: 26}: ids.ID{0x1}, - ids.NodeWeight{Node: n2, Weight: 50}: ids.ID{0x2}, + totalValidatorWeight: 100, lastAcceptedBlockID: nodeWeightsToBlocks{ + ids.NodeWeight{ID: n1, Weight: 26}: ids.ID{0x1}, + ids.NodeWeight{ID: n2, Weight: 50}: ids.ID{0x2}, }, }, processing: map[ids.ID]struct{}{ @@ -219,9 +213,9 @@ func TestFailedCatchingUp(t *testing.T) { { description: "In sync with the majority", input: snapshot{ - totalValidatorWeight: 100, nodeWeightsToBlocks: nodeWeightsToBlocks{ - ids.NodeWeight{Node: n1, Weight: 75}: ids.ID{0x1}, - ids.NodeWeight{Node: n2, Weight: 25}: ids.ID{0x2}, + totalValidatorWeight: 100, lastAcceptedBlockID: nodeWeightsToBlocks{ + ids.NodeWeight{ID: n1, Weight: 75}: ids.ID{0x1}, + ids.NodeWeight{ID: n2, Weight: 25}: ids.ID{0x2}, }, }, processing: map[ids.ID]struct{}{ @@ -234,34 +228,22 @@ func TestFailedCatchingUp(t *testing.T) { var buff logBuffer log := logging.NewLogger("", logging.NewWrappedCore(logging.Verbo, &buff, logging.Plain.ConsoleEncoder())) - sd := newStragglerDetector(nil, log, 0.75, - func() (ids.ID, uint64) { - return testCase.lastAccepted, 0 - }, - func() set.Set[ids.NodeWeight] { - var set set.Set[ids.NodeWeight] - for _, nw := range testCase.connectedValidators { - set.Add(nw) - } - return set - }, func() float64 { return testCase.connectedPercent }, - func(id ids.ID) bool { + sa := &snapshotAnalyzer{ + log: log, + processing: func(id ids.ID) bool { _, ok := testCase.processing[id] return ok }, - func(vdr ids.NodeID) (ids.ID, bool) { - id, ok := testCase.lastAcceptedFromNodes[vdr] - return id, ok - }) + } - require.Equal(t, testCase.expected, sd.failedCatchingUp(testCase.input)) + require.Equal(t, testCase.expected, sa.areWeBehindTheRest(testCase.input)) require.Contains(t, buff.String(), testCase.expectedLogged) }) } } func TestCheckIfWeAreStragglingBehind(t *testing.T) { - fakeClock := make(chan time.Time, 1) + var fakeClock mockable.Clock snapshots := make(chan snapshot, 1) assertNoSnapshotsRemain := func() { @@ -273,7 +255,7 @@ func TestCheckIfWeAreStragglingBehind(t *testing.T) { } nonEmptySnap := snapshot{ totalValidatorWeight: 100, - nodeWeightsToBlocks: nodeWeightsToBlocks{ + lastAcceptedBlockID: nodeWeightsToBlocks{ ids.NodeWeight{Weight: 100}: ids.Empty, }, } @@ -286,16 +268,13 @@ func TestCheckIfWeAreStragglingBehind(t *testing.T) { sd := stragglerDetector{ stragglerDetectorConfig: stragglerDetectorConfig{ minStragglerCheckInterval: time.Second, - getTime: func() time.Time { - now := <-fakeClock - return now - }, - log: log, + getTime: fakeClock.Time, + log: log, getSnapshot: func() (snapshot, bool) { s := <-snapshots return s, !s.isEmpty() }, - haveWeFailedCatchingUp: func(_ snapshot) bool { + areWeBehindTheRest: func(_ snapshot) bool { return haveWeFailedCatchingUpReturns }, }, @@ -329,9 +308,10 @@ func TestCheckIfWeAreStragglingBehind(t *testing.T) { }, }, { - description: "Advance time some more to the first check where the snapshot isn't empty", - timeAdvanced: time.Second * 2, - snapshotsRead: []snapshot{nonEmptySnap}, + description: "Advance time some more to the first check where the snapshot isn't empty", + timeAdvanced: time.Second * 2, + snapshotsRead: []snapshot{nonEmptySnap}, + haveWeFailedCatchingUpReturns: true, evalExtraAssertions: func(t *testing.T) { require.Empty(t, buff.String()) }, @@ -346,11 +326,12 @@ func TestCheckIfWeAreStragglingBehind(t *testing.T) { }, }, { - description: "The third snapshot is due to a fresh check", - timeAdvanced: time.Second * 2, - snapshotsRead: []snapshot{nonEmptySnap}, + description: "The third snapshot is due to a fresh check", + timeAdvanced: time.Second * 2, + snapshotsRead: []snapshot{nonEmptySnap}, + haveWeFailedCatchingUpReturns: true, // We carry over the total straggling time from previous testCase to this check, - // as we need the next check to nullify it. + // as we expect the next check to nullify it. expectedStragglingTime: time.Second * 2, evalExtraAssertions: func(_ *testing.T) {}, }, @@ -361,8 +342,9 @@ func TestCheckIfWeAreStragglingBehind(t *testing.T) { }, } { t.Run(testCase.description, func(t *testing.T) { + fmt.Println(testCase.description) fakeTime = fakeTime.Add(testCase.timeAdvanced) - fakeClock <- fakeTime + fakeClock.Set(fakeTime) // Load the snapshot expected to be retrieved in this testCase, if applicable. if len(testCase.snapshotsRead) > 0 { diff --git a/snow/networking/handler/health.go b/snow/networking/handler/health.go index 4c43e0ead003..25c1463e4f78 100644 --- a/snow/networking/handler/health.go +++ b/snow/networking/handler/health.go @@ -72,7 +72,7 @@ func (h *handler) getDisconnectedValidators() set.Set[ids.NodeID] { func withoutWeights(weights set.Set[ids.NodeWeight]) set.Set[ids.NodeID] { var res set.Set[ids.NodeID] for _, nw := range weights.List() { - res.Add(nw.Node) + res.Add(nw.ID) } return res } From 438768de7c78f5ac3d2acefed5314039751878e7 Mon Sep 17 00:00:00 2001 From: Yacov Manevich Date: Wed, 16 Oct 2024 18:01:48 +0200 Subject: [PATCH 6/6] Use last accepted height from Chits message to detect if we're behind Signed-off-by: Yacov Manevich --- snow/engine/snowman/engine_decorator.go | 22 ++- snow/engine/snowman/engine_decorator_test.go | 6 +- snow/engine/snowman/straggler_detect.go | 165 +++++++------------ snow/engine/snowman/straggler_detect_test.go | 160 ++++++++---------- 4 files changed, 143 insertions(+), 210 deletions(-) diff --git a/snow/engine/snowman/engine_decorator.go b/snow/engine/snowman/engine_decorator.go index 41102695dc79..11644ea80ae9 100644 --- a/snow/engine/snowman/engine_decorator.go +++ b/snow/engine/snowman/engine_decorator.go @@ -22,17 +22,21 @@ type decoratedEngineWithStragglerDetector struct { func NewDecoratedEngineWithStragglerDetector(e *Engine, time func() time.Time, f func(time.Duration)) common.Engine { minConfRatio := float64(e.Params.AlphaConfidence) / float64(e.Params.K) + subnet := e.Ctx.SubnetID + sa := &snapshotAnalyzer{ - log: e.Config.Ctx.Log, - processing: e.Consensus.Processing, + lastAcceptedHeight: onlyHeight(e.Consensus.LastAccepted), + log: e.Config.Ctx.Log, } s := &snapshotter{ - log: e.Config.Ctx.Log, - connectedValidators: e.Config.ConnectedValidators.ConnectedValidators, - minConfirmationThreshold: minConfRatio, - lastAcceptedByNodeID: e.acceptedFrontiers.LastAccepted, - lastAccepted: dropHeight(e.Consensus.LastAccepted), + totalWeight: func() (uint64, error) { + return e.Validators.TotalWeight(subnet) + }, + log: e.Config.Ctx.Log, + connectedValidators: e.Config.ConnectedValidators.ConnectedValidators, + minConfirmationThreshold: minConfRatio, + lastAcceptedHeightByNodeID: e.acceptedFrontiers.LastAccepted, } conf := stragglerDetectorConfig{ @@ -52,12 +56,12 @@ func NewDecoratedEngineWithStragglerDetector(e *Engine, time func() time.Time, f } } -func (de *decoratedEngineWithStragglerDetector) Chits(ctx context.Context, nodeID ids.NodeID, requestID uint32, preferredID ids.ID, preferredIDAtHeight ids.ID, acceptedID ids.ID) error { +func (de *decoratedEngineWithStragglerDetector) Chits(ctx context.Context, nodeID ids.NodeID, requestID uint32, preferredID ids.ID, preferredIDAtHeight ids.ID, acceptedID ids.ID, acceptedHeight uint64) error { behindDuration := de.sd.CheckIfWeAreStragglingBehind() if behindDuration > 0 { de.Engine.Config.Ctx.Log.Info("We are behind the rest of the network", zap.Float64("seconds", behindDuration.Seconds())) } de.Engine.metrics.stragglingDuration.Set(float64(behindDuration)) de.f(behindDuration) - return de.Engine.Chits(ctx, nodeID, requestID, preferredID, preferredIDAtHeight, acceptedID) + return de.Engine.Chits(ctx, nodeID, requestID, preferredID, preferredIDAtHeight, acceptedID, acceptedHeight) } diff --git a/snow/engine/snowman/engine_decorator_test.go b/snow/engine/snowman/engine_decorator_test.go index c40664c1e051..2dd4ff1fd303 100644 --- a/snow/engine/snowman/engine_decorator_test.go +++ b/snow/engine/snowman/engine_decorator_test.go @@ -54,12 +54,12 @@ func TestEngineStragglerDetector(t *testing.T) { now := time.Now() fakeClock.Set(now) - require.NoError(decoratedEngine.Chits(context.Background(), peerID, 0, parent.ID(), parent.ID(), parent.ID())) + require.NoError(decoratedEngine.Chits(context.Background(), peerID, 0, parent.ID(), parent.ID(), parent.ID(), 100)) now = now.Add(minStragglerCheckInterval * 2) fakeClock.Set(now) - require.NoError(decoratedEngine.Chits(context.Background(), peerID, 0, parent.ID(), parent.ID(), parent.ID())) + require.NoError(decoratedEngine.Chits(context.Background(), peerID, 0, parent.ID(), parent.ID(), parent.ID(), 100)) now = now.Add(minStragglerCheckInterval * 2) fakeClock.Set(now) - require.NoError(decoratedEngine.Chits(context.Background(), peerID, 0, parent.ID(), parent.ID(), parent.ID())) + require.NoError(decoratedEngine.Chits(context.Background(), peerID, 0, parent.ID(), parent.ID(), parent.ID(), 100)) require.Empty(listenerShouldInvokeWith) } diff --git a/snow/engine/snowman/straggler_detect.go b/snow/engine/snowman/straggler_detect.go index 5a8525239368..01e90d58265e 100644 --- a/snow/engine/snowman/straggler_detect.go +++ b/snow/engine/snowman/straggler_detect.go @@ -4,6 +4,7 @@ package snowman import ( + "errors" "fmt" "time" @@ -21,7 +22,6 @@ const ( minStragglerCheckInterval = 10 * time.Second stakeThresholdForStragglerSuspicion = 0.75 minimumStakeThresholdRequiredForNetworkInfo = 0.8 - knownStakeThresholdRequiredForAnalysis = 0.8 ) type stragglerDetectorConfig struct { @@ -110,15 +110,13 @@ func (sd *stragglerDetector) evaluateSnapshot(now time.Time) { sd.prevSnapshot = snapshot{} } +// snapshotAnalyzer analyzes a snapshot and returns true whether +// the caller is behind the majority of the network type snapshotAnalyzer struct { + // log is used to log events log logging.Logger - - // processing returns whether this block ID is known and its descendants have not yet been accepted by consensus. - // This means that when the last accepted block is given as input, true is returned, as by definition - // its descendants have not been accepted by consensus, but this block is known. - // For any block ID belonging to an ancestor of the last accepted block, false is returned, - // as the last accepted block has been accepted by consensus. - processing func(id ids.ID) bool + // lastAcceptedHeight returns the last accepted block of this node. + lastAcceptedHeight func() uint64 } func (sa *snapshotAnalyzer) areWeBehindTheRest(s snapshot) bool { @@ -126,17 +124,18 @@ func (sa *snapshotAnalyzer) areWeBehindTheRest(s snapshot) bool { return false } - totalValidatorWeight, nodeWeightsToBlocks := s.totalValidatorWeight, s.lastAcceptedBlockID + totalValidatorWeight, nodeWeightsToBlockHeights := s.totalValidatorWeight, s.lastAcceptedBlockHeight - processingWeight, err := nodeWeightsToBlocks.filter(sa.processing).totalWeight() + filter := higherThanGivenHeight(sa.lastAcceptedHeight()) + totalWeightOfNodesAheadOfUs, err := nodeWeightsToBlockHeights.filter(filter).totalWeight() if err != nil { sa.log.Error("Failed computing total weight", zap.Error(err)) return false } - sa.log.Trace("Counted total weight that accepted blocks we're still processing", zap.Uint64("weight", processingWeight)) + sa.log.Trace("Counted total weight of nodes that are ahead of us", zap.Uint64("weight", totalWeightOfNodesAheadOfUs)) - ratio := float64(processingWeight) / float64(totalValidatorWeight) + ratio := float64(totalWeightOfNodesAheadOfUs) / float64(totalValidatorWeight) if ratio > stakeThresholdForStragglerSuspicion { sa.log.Trace("We are straggling behind", zap.Float64("ratio", ratio)) @@ -148,145 +147,101 @@ func (sa *snapshotAnalyzer) areWeBehindTheRest(s snapshot) bool { return false } +func higherThanGivenHeight(givenHeight uint64) func(height uint64) bool { + return func(height uint64) bool { + return height > givenHeight + } +} + type snapshotter struct { // log logs events log logging.Logger + // totalWeight returns the total amount of weight. + totalWeight func() (uint64, error) + // minConfirmationThreshold is the minimum stake percentage that below it, we do not check if we are stragglers. minConfirmationThreshold float64 - // lastAccepted returns the last accepted block of this node. - lastAccepted func() ids.ID - // connectedValidators returns a set of tuples of NodeID and corresponding weight. connectedValidators func() set.Set[ids.NodeWeight] - // lastAcceptedByNodeID returns the last accepted height a node has reported, or false if it is unknown. - lastAcceptedByNodeID func(id ids.NodeID) (ids.ID, bool) -} - -func (s *snapshotter) validateNetInfo(netInfo netInfo) bool { - if netInfo.totalValidatorWeight == 0 { - s.log.Trace("Connected to zero weight") - return false - } - - totalKnownLastBlockStakePercent := float64(netInfo.totalWeightWeKnowItsLastAcceptedBlock) / float64(netInfo.totalValidatorWeight) - stakeAheadOfUs := float64(netInfo.totalPendingStake) / float64(netInfo.totalValidatorWeight) - - // Ensure we have collected last accepted blocks for at least 80% (or so) stake of the total weight we are connected to. - if totalKnownLastBlockStakePercent < minimumStakeThresholdRequiredForNetworkInfo { - s.log.Trace("Not collected enough information about last accepted blocks for the validators we are connected to", - zap.Float64("ratio", totalKnownLastBlockStakePercent)) - return false - } - - if stakeAheadOfUs < knownStakeThresholdRequiredForAnalysis { - s.log.Trace("Most stake we're connected to has the same height as we do", - zap.Float64("ratio", stakeAheadOfUs)) - return false - } - - return true + // lastAcceptedHeightByNodeID returns the last accepted height a node has reported, or false if it is unknown. + lastAcceptedHeightByNodeID func(id ids.NodeID) (ids.ID, uint64, bool) } func (s *snapshotter) getNetworkSnapshot() (snapshot, bool) { - ourLastAcceptedBlock := s.lastAccepted() - - netInfo, err := s.getNetworkInfo(ourLastAcceptedBlock) + totalValidatorWeight, nodeWeightsToLastAcceptedHeight, err := s.getNetworkInfo() if err != nil { - return snapshot{}, false - } - - if !s.validateNetInfo(netInfo) { + s.log.Trace("Failed getting network info", zap.Error(err)) return snapshot{}, false } snap := snapshot{ - lastAcceptedBlockID: netInfo.nodeWeightToLastAccepted, - totalValidatorWeight: netInfo.totalValidatorWeight, + lastAcceptedBlockHeight: nodeWeightsToLastAcceptedHeight, + totalValidatorWeight: totalValidatorWeight, } return snap, true } -func (s *snapshotter) getNetworkInfo(ourLastAcceptedBlock ids.ID) (netInfo, error) { - var res netInfo +func (s *snapshotter) getNetworkInfo() (uint64, nodeWeightsToHeight, error) { + totalValidatorWeight, err := s.totalWeight() + if err != nil { + return 0, nil, err + } validators := nodeWeights(s.connectedValidators().List()) - nodeWeightToLastAccepted := make(nodeWeightsToBlocks, len(validators)) + nodeWeightToLastAcceptedHeight := make(nodeWeightsToHeight, len(validators)) for _, vdr := range validators { - lastAccepted, ok := s.lastAcceptedByNodeID(vdr.ID) + _, lastAcceptedHeight, ok := s.lastAcceptedHeightByNodeID(vdr.ID) if !ok { continue } - nodeWeightToLastAccepted[vdr] = lastAccepted + nodeWeightToLastAcceptedHeight[vdr] = lastAcceptedHeight } - totalValidatorWeight, err := validators.totalWeight() + totalKnownConnectedWeight, err := nodeWeightToLastAcceptedHeight.totalWeight() if err != nil { - s.log.Error("Failed computing total weight", zap.Error(err)) - return netInfo{}, err + return 0, nil, err } - res.totalValidatorWeight = totalValidatorWeight - totalWeightWeKnowItsLastAcceptedBlock, err := nodeWeightToLastAccepted.totalWeight() - if err != nil { - s.log.Error("Failed computing total weight", zap.Error(err)) - return netInfo{}, err + if totalValidatorWeight == 0 { + return 0, nil, errors.New("connected to zero weight") } - res.totalWeightWeKnowItsLastAcceptedBlock = totalWeightWeKnowItsLastAcceptedBlock - - prevLastAcceptedCount := len(nodeWeightToLastAccepted) - // Ensure we have collected last accepted blocks that are not our own last accepted block. - nodeWeightToLastAccepted = nodeWeightToLastAccepted.filter(func(id ids.ID) bool { - return ourLastAcceptedBlock.Compare(id) != 0 - }) + knownPercentageOfConnectedValidators := 100 * float64(totalKnownConnectedWeight) / float64(totalValidatorWeight) - res.nodeWeightToLastAccepted = nodeWeightToLastAccepted - - totalPendingStake, err := nodeWeightToLastAccepted.totalWeight() - if err != nil { - s.log.Error("Failed computing total weight", zap.Error(err)) - return netInfo{}, err + if knownPercentageOfConnectedValidators < 100*minimumStakeThresholdRequiredForNetworkInfo { + s.log.Trace("Not collected enough information about last accepted block heights", + zap.Int("percentage", int(knownPercentageOfConnectedValidators))) + return 0, nil, errors.New("not enough information") } - res.totalPendingStake = totalPendingStake - - s.log.Trace("Excluding nodes with our own height", zap.Int("prev", prevLastAcceptedCount), zap.Uint64("new", totalPendingStake)) - - return res, nil -} - -type netInfo struct { - totalPendingStake uint64 - totalValidatorWeight uint64 - totalWeightWeKnowItsLastAcceptedBlock uint64 - nodeWeightToLastAccepted nodeWeightsToBlocks + return totalValidatorWeight, nodeWeightToLastAcceptedHeight, nil } type snapshot struct { - totalValidatorWeight uint64 - lastAcceptedBlockID nodeWeightsToBlocks + totalValidatorWeight uint64 + lastAcceptedBlockHeight nodeWeightsToHeight } func (s snapshot) isEmpty() bool { - return s.totalValidatorWeight == 0 || len(s.lastAcceptedBlockID) == 0 + return s.totalValidatorWeight == 0 || len(s.lastAcceptedBlockHeight) == 0 } -type nodeWeightsToBlocks map[ids.NodeWeight]ids.ID +type nodeWeightsToHeight map[ids.NodeWeight]uint64 -func (nwb nodeWeightsToBlocks) totalWeight() (uint64, error) { - return nodeWeights(maps.Keys(nwb)).totalWeight() +func (nwh nodeWeightsToHeight) totalWeight() (uint64, error) { + return nodeWeights(maps.Keys(nwh)).totalWeight() } -// dropHeight removes the second return parameter from the function f() and keeps its first return parameter, ids.ID. -func dropHeight(f func() (ids.ID, uint64)) func() ids.ID { - return func() ids.ID { - id, _ := f() - return id +// onlyHeight removes the first return parameter from the function f() and keeps its second return parameter, the height. +func onlyHeight(f func() (ids.ID, uint64)) func() uint64 { + return func() uint64 { + _, height := f() + return height } } @@ -304,11 +259,11 @@ func (nws nodeWeights) totalWeight() (uint64, error) { return weight, nil } -func (nwb nodeWeightsToBlocks) filter(f func(ids.ID) bool) nodeWeightsToBlocks { - filtered := make(nodeWeightsToBlocks, len(nwb)) - for nw, id := range nwb { - if f(id) { - filtered[nw] = id +func (nwh nodeWeightsToHeight) filter(f func(height uint64) bool) nodeWeightsToHeight { + filtered := make(nodeWeightsToHeight, len(nwh)) + for nw, height := range nwh { + if f(height) { + filtered[nw] = height } } return filtered diff --git a/snow/engine/snowman/straggler_detect_test.go b/snow/engine/snowman/straggler_detect_test.go index 63d072ae770d..ade29dc7c12e 100644 --- a/snow/engine/snowman/straggler_detect_test.go +++ b/snow/engine/snowman/straggler_detect_test.go @@ -4,7 +4,6 @@ package snowman import ( - "fmt" "math" "testing" "time" @@ -42,9 +41,9 @@ func TestNodeWeightsOverflow(t *testing.T) { } func TestNodeWeights2Blocks(t *testing.T) { - nw2b := nodeWeightsToBlocks{ - ids.NodeWeight{Weight: 5}: ids.Empty, - ids.NodeWeight{Weight: 10}: ids.Empty, + nw2b := nodeWeightsToHeight{ + ids.NodeWeight{Weight: 5}: 100, + ids.NodeWeight{Weight: 10}: 101, } total, err := nw2b.totalWeight() @@ -68,10 +67,10 @@ func TestGetNetworkSnapshot(t *testing.T) { } for _, testCase := range []struct { + totalWeight uint64 description string - lastAccepted ids.ID - lastAcceptedFromNodes map[ids.NodeID]ids.ID - processing map[ids.ID]struct{} + lastAcceptedHeight uint64 + lastAcceptedFromNodes map[ids.NodeID]uint64 connectedValidators func() set.Set[ids.NodeWeight] expectedSnapshot snapshot expectedOK bool @@ -80,48 +79,53 @@ func TestGetNetworkSnapshot(t *testing.T) { { description: "connected to zero weight", connectedValidators: connectedValidators([]ids.NodeWeight{}), - expectedLogged: "Connected to zero weight", + expectedLogged: "connected to zero weight", }, { description: "not enough info", + totalWeight: 9999999999, connectedValidators: connectedValidators([]ids.NodeWeight{{Weight: 1, ID: n1}, {Weight: 999999, ID: n2}}), - lastAcceptedFromNodes: map[ids.NodeID]ids.ID{ - n1: {0x1}, + lastAcceptedFromNodes: map[ids.NodeID]uint64{ + n1: 100, }, - expectedLogged: "Not collected enough information about last accepted blocks", + expectedLogged: "Not collected enough information about last accepted block", }, { description: "we're in sync", + totalWeight: 999999, connectedValidators: connectedValidators([]ids.NodeWeight{{Weight: 999999, ID: n1}}), - lastAcceptedFromNodes: map[ids.NodeID]ids.ID{ - n1: {0x1}, + lastAcceptedFromNodes: map[ids.NodeID]uint64{ + n1: 100, }, - lastAccepted: ids.ID{0x1}, - expectedLogged: "Most stake we're connected to has the same height as we do", + lastAcceptedHeight: 100, + expectedOK: true, + expectedSnapshot: snapshot{totalValidatorWeight: 999999, lastAcceptedBlockHeight: nodeWeightsToHeight{ + {Weight: 999999, ID: n1}: 100, + }}, }, { description: "we're behind", + totalWeight: 999999, connectedValidators: connectedValidators([]ids.NodeWeight{{Weight: 999999, ID: n1}}), - lastAcceptedFromNodes: map[ids.NodeID]ids.ID{ - n1: {0x1}, + lastAcceptedFromNodes: map[ids.NodeID]uint64{ + n1: 120, }, - processing: map[ids.ID]struct{}{{0x1}: {}}, - lastAccepted: ids.ID{0x0}, - expectedSnapshot: snapshot{totalValidatorWeight: 999999, lastAcceptedBlockID: nodeWeightsToBlocks{ - ids.NodeWeight{ID: n1, Weight: 999999}: {0x1}, + lastAcceptedHeight: 100, + expectedSnapshot: snapshot{totalValidatorWeight: 999999, lastAcceptedBlockHeight: nodeWeightsToHeight{ + ids.NodeWeight{ID: n1, Weight: 999999}: 120, }}, expectedOK: true, }, { description: "we're not behind", + totalWeight: 999999, connectedValidators: connectedValidators([]ids.NodeWeight{{Weight: 999999, ID: n1}}), - lastAcceptedFromNodes: map[ids.NodeID]ids.ID{ - n1: {0x1}, + lastAcceptedFromNodes: map[ids.NodeID]uint64{ + n1: 100, }, - processing: map[ids.ID]struct{}{{0x2}: {}}, - lastAccepted: ids.ID{0x0}, - expectedSnapshot: snapshot{totalValidatorWeight: 999999, lastAcceptedBlockID: nodeWeightsToBlocks{ - ids.NodeWeight{ID: n1, Weight: 999999}: {0x1}, + lastAcceptedHeight: 100, + expectedSnapshot: snapshot{totalValidatorWeight: 999999, lastAcceptedBlockHeight: nodeWeightsToHeight{ + ids.NodeWeight{ID: n1, Weight: 999999}: 100, }}, expectedOK: true, }, @@ -131,15 +135,15 @@ func TestGetNetworkSnapshot(t *testing.T) { log := logging.NewLogger("", logging.NewWrappedCore(logging.Verbo, &buff, logging.Plain.ConsoleEncoder())) s := &snapshotter{ + totalWeight: func() (uint64, error) { + return testCase.totalWeight, nil + }, log: log, connectedValidators: testCase.connectedValidators, minConfirmationThreshold: 0.75, - lastAccepted: func() ids.ID { - return testCase.lastAccepted - }, - lastAcceptedByNodeID: func(vdr ids.NodeID) (ids.ID, bool) { - id, ok := testCase.lastAcceptedFromNodes[vdr] - return id, ok + lastAcceptedHeightByNodeID: func(vdr ids.NodeID) (ids.ID, uint64, bool) { + height, ok := testCase.lastAcceptedFromNodes[vdr] + return ids.Empty, height, ok }, } @@ -158,9 +162,8 @@ func TestFailedCatchingUp(t *testing.T) { for _, testCase := range []struct { description string - lastAccepted ids.ID + lastAccepted uint64 lastAcceptedFromNodes map[ids.NodeID]ids.ID - processing map[ids.ID]struct{} connectedValidators []ids.NodeWeight input snapshot expected bool @@ -170,57 +173,33 @@ func TestFailedCatchingUp(t *testing.T) { description: "stake overflow", input: snapshot{ totalValidatorWeight: 100, - lastAcceptedBlockID: nodeWeightsToBlocks{ - ids.NodeWeight{ID: n1, Weight: math.MaxUint64 - 10}: ids.ID{0x1}, - ids.NodeWeight{ID: n2, Weight: 11}: ids.ID{0x2}, + lastAcceptedBlockHeight: nodeWeightsToHeight{ + ids.NodeWeight{ID: n1, Weight: math.MaxUint64 - 10}: 100, + ids.NodeWeight{ID: n2, Weight: 11}: 100, }, }, - processing: map[ids.ID]struct{}{ - {0x1}: {}, - {0x2}: {}, - }, expectedLogged: "Failed computing total weight", }, - { - description: "Straggling behind stake minority", - input: snapshot{ - totalValidatorWeight: 100, lastAcceptedBlockID: nodeWeightsToBlocks{ - ids.NodeWeight{ID: n1, Weight: 25}: ids.ID{0x1}, - ids.NodeWeight{ID: n2, Weight: 50}: ids.ID{0x2}, - }, - }, - processing: map[ids.ID]struct{}{ - {0x1}: {}, - {0x2}: {}, - }, - expectedLogged: "Nodes ahead of us", - }, { description: "Straggling behind stake majority", input: snapshot{ - totalValidatorWeight: 100, lastAcceptedBlockID: nodeWeightsToBlocks{ - ids.NodeWeight{ID: n1, Weight: 26}: ids.ID{0x1}, - ids.NodeWeight{ID: n2, Weight: 50}: ids.ID{0x2}, + totalValidatorWeight: 100, lastAcceptedBlockHeight: nodeWeightsToHeight{ + ids.NodeWeight{ID: n1, Weight: 26}: 100, + ids.NodeWeight{ID: n2, Weight: 50}: 100, }, }, - processing: map[ids.ID]struct{}{ - {0x1}: {}, - {0x2}: {}, - }, expectedLogged: "We are straggling behind", expected: true, }, { - description: "In sync with the majority", + description: "In sync with the majority", + lastAccepted: 100, input: snapshot{ - totalValidatorWeight: 100, lastAcceptedBlockID: nodeWeightsToBlocks{ - ids.NodeWeight{ID: n1, Weight: 75}: ids.ID{0x1}, - ids.NodeWeight{ID: n2, Weight: 25}: ids.ID{0x2}, + totalValidatorWeight: 100, lastAcceptedBlockHeight: nodeWeightsToHeight{ + ids.NodeWeight{ID: n1, Weight: 75}: 100, + ids.NodeWeight{ID: n2, Weight: 25}: 100, }, }, - processing: map[ids.ID]struct{}{ - {0x2}: {}, - }, expectedLogged: "Nodes ahead of us", }, } { @@ -229,11 +208,10 @@ func TestFailedCatchingUp(t *testing.T) { log := logging.NewLogger("", logging.NewWrappedCore(logging.Verbo, &buff, logging.Plain.ConsoleEncoder())) sa := &snapshotAnalyzer{ - log: log, - processing: func(id ids.ID) bool { - _, ok := testCase.processing[id] - return ok + lastAcceptedHeight: func() uint64 { + return testCase.lastAccepted }, + log: log, } require.Equal(t, testCase.expected, sa.areWeBehindTheRest(testCase.input)) @@ -255,13 +233,11 @@ func TestCheckIfWeAreStragglingBehind(t *testing.T) { } nonEmptySnap := snapshot{ totalValidatorWeight: 100, - lastAcceptedBlockID: nodeWeightsToBlocks{ - ids.NodeWeight{Weight: 100}: ids.Empty, + lastAcceptedBlockHeight: nodeWeightsToHeight{ + ids.NodeWeight{Weight: 100}: 100, }, } - var haveWeFailedCatchingUpReturns bool - var buff logBuffer log := logging.NewLogger("", logging.NewWrappedCore(logging.Verbo, &buff, logging.Plain.ConsoleEncoder())) @@ -274,9 +250,6 @@ func TestCheckIfWeAreStragglingBehind(t *testing.T) { s := <-snapshots return s, !s.isEmpty() }, - areWeBehindTheRest: func(_ snapshot) bool { - return haveWeFailedCatchingUpReturns - }, }, } @@ -285,25 +258,25 @@ func TestCheckIfWeAreStragglingBehind(t *testing.T) { for _, testCase := range []struct { description string timeAdvanced time.Duration - evalExtraAssertions func(t *testing.T) + evalExtraAssertions func(*testing.T, *stragglerDetector) expectedStragglingTime time.Duration snapshotsRead []snapshot haveWeFailedCatchingUpReturns bool }{ { description: "First invocation only sets the time", - evalExtraAssertions: func(_ *testing.T) {}, + evalExtraAssertions: func(_ *testing.T, _ *stragglerDetector) {}, }, { description: "Should not check yet, as it is not time yet", timeAdvanced: time.Millisecond * 500, - evalExtraAssertions: func(_ *testing.T) {}, + evalExtraAssertions: func(_ *testing.T, _ *stragglerDetector) {}, }, { description: "Advance time some more, so now we should check", timeAdvanced: time.Millisecond * 501, snapshotsRead: []snapshot{{}}, - evalExtraAssertions: func(t *testing.T) { + evalExtraAssertions: func(t *testing.T, _ *stragglerDetector) { require.Contains(t, buff.String(), "No node snapshot obtained") }, }, @@ -312,7 +285,7 @@ func TestCheckIfWeAreStragglingBehind(t *testing.T) { timeAdvanced: time.Second * 2, snapshotsRead: []snapshot{nonEmptySnap}, haveWeFailedCatchingUpReturns: true, - evalExtraAssertions: func(t *testing.T) { + evalExtraAssertions: func(t *testing.T, _ *stragglerDetector) { require.Empty(t, buff.String()) }, }, @@ -321,7 +294,7 @@ func TestCheckIfWeAreStragglingBehind(t *testing.T) { timeAdvanced: time.Second * 2, expectedStragglingTime: time.Second * 2, haveWeFailedCatchingUpReturns: true, - evalExtraAssertions: func(t *testing.T) { + evalExtraAssertions: func(t *testing.T, sd *stragglerDetector) { require.Empty(t, sd.prevSnapshot) }, }, @@ -333,16 +306,15 @@ func TestCheckIfWeAreStragglingBehind(t *testing.T) { // We carry over the total straggling time from previous testCase to this check, // as we expect the next check to nullify it. expectedStragglingTime: time.Second * 2, - evalExtraAssertions: func(_ *testing.T) {}, + evalExtraAssertions: func(_ *testing.T, _ *stragglerDetector) {}, }, { description: "The fourth check returns we have succeeded in catching up", timeAdvanced: time.Second * 2, - evalExtraAssertions: func(_ *testing.T) {}, + evalExtraAssertions: func(_ *testing.T, _ *stragglerDetector) {}, }, } { t.Run(testCase.description, func(t *testing.T) { - fmt.Println(testCase.description) fakeTime = fakeTime.Add(testCase.timeAdvanced) fakeClock.Set(fakeTime) @@ -351,14 +323,16 @@ func TestCheckIfWeAreStragglingBehind(t *testing.T) { snapshots <- testCase.snapshotsRead[0] } - haveWeFailedCatchingUpReturns = testCase.haveWeFailedCatchingUpReturns + sd.areWeBehindTheRest = func(_ snapshot) bool { + return testCase.haveWeFailedCatchingUpReturns + } + require.Equal(t, testCase.expectedStragglingTime, sd.CheckIfWeAreStragglingBehind()) - testCase.evalExtraAssertions(t) + testCase.evalExtraAssertions(t, &sd) // Cleanup the log buffer, and make sure no snapshots remain for next testCase. buff.Reset() assertNoSnapshotsRemain() - haveWeFailedCatchingUpReturns = false }) } }