From 91cc26e68774d89a4c56717dfdf6b0f4fbe564ae Mon Sep 17 00:00:00 2001
From: Yacov Manevich
Date: Wed, 11 Sep 2024 23:20:34 +0200
Subject: [PATCH] Detect if our node is behind the majority

This commit adds a mechanism that detects whether our node is behind the
majority of the stake. The intent is to later have this mechanism be the
trigger for the bootstrapping mechanism. Currently, the bootstrapping
mechanism is only active upon node boot, but not at a later point.

The mechanism works in the following manner:

- It intercepts the snowman engine's Chits message handling; upon every
  reception of a Chits message, the mechanism that detects whether the node
  is a straggler (a node whose ledger height is behind the rest) may be
  invoked, provided it was not invoked too recently.
- The mechanism draws statistics from the validators known to it, and
  computes the latest accepted block reported by each validator.
- The mechanism then determines which blocks are still pending processing
  (a block pending processing is known but has not been accepted).
- The mechanism then collects a snapshot of the blocks it has not accepted
  yet, together with the amount of stake that has accepted each such block.
- The mechanism then waits for its next invocation, in order to see whether
  it has since accepted the blocks correlated with enough stake.
- If the stake correlated with blocks that other nodes have accepted, but
  that our node has not, exceeds a threshold, the mechanism announces that
  the node is behind and returns the time period between the two
  invocations.
- The mechanism sums the total time during which it has detected that the
  node is behind; once a sampling concludes the node is no longer behind,
  the total is reset.

Signed-off-by: Yacov Manevich
---
 chains/manager.go                            |  11 +-
 ids/node_weight.go                           |   9 +
 snow/engine/common/tracker/peers.go          |  26 +-
 snow/engine/common/tracker/peers_test.go     |  25 ++
 snow/engine/snowman/engine_decorator.go      |  59 +++
 snow/engine/snowman/engine_decorator_test.go |  75 ++++
 snow/engine/snowman/metrics.go               |   6 +
 snow/engine/snowman/straggler_detect.go      | 307 +++++++++++++++
 snow/engine/snowman/straggler_detect_test.go | 376 +++++++++++++++++++
 snow/networking/handler/health.go            |  10 +-
 10 files changed, 891 insertions(+), 13 deletions(-)
 create mode 100644 ids/node_weight.go
 create mode 100644 snow/engine/snowman/engine_decorator.go
 create mode 100644 snow/engine/snowman/engine_decorator_test.go
 create mode 100644 snow/engine/snowman/straggler_detect.go
 create mode 100644 snow/engine/snowman/straggler_detect_test.go

diff --git a/chains/manager.go b/chains/manager.go
index 2fdd3c0b73f9..b2812bcdadf5 100644
--- a/chains/manager.go
+++ b/chains/manager.go
@@ -1339,12 +1339,19 @@ func (m *manager) createSnowmanChain(
 		Consensus:           consensus,
 		PartialSync:         m.PartialSyncPrimaryNetwork && ctx.ChainID == constants.PlatformChainID,
 	}
-	var engine common.Engine
-	engine, err = smeng.New(engineConfig)
+
+	sme, err := smeng.New(engineConfig)
 	if err != nil {
 		return nil, fmt.Errorf("error initializing snowman engine: %w", err)
 	}
 
+	ed := smeng.EngineStragglerDetector{
+		Listener: func(_ time.Duration) {},
+		Time:     time.Now,
+	}
+
+	engine := ed.AttachToEngine(sme)
+
 	if m.TracingEnabled {
 		engine = common.TraceEngine(engine, m.Tracer)
 	}
diff --git a/ids/node_weight.go b/ids/node_weight.go
new file mode 100644
index 000000000000..21309586ca2a
--- /dev/null
+++ b/ids/node_weight.go
@@ -0,0 +1,9 @@
+// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
+// See the file LICENSE for licensing terms.
+ +package ids + +type NodeWeight struct { + Node NodeID + Weight uint64 +} diff --git a/snow/engine/common/tracker/peers.go b/snow/engine/common/tracker/peers.go index 37bf7b10f026..65dda6f7d1ff 100644 --- a/snow/engine/common/tracker/peers.go +++ b/snow/engine/common/tracker/peers.go @@ -9,7 +9,6 @@ import ( "sync" "github.com/prometheus/client_golang/prometheus" - "golang.org/x/exp/maps" "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/snow/validators" @@ -37,10 +36,10 @@ type Peers interface { SampleValidator() (ids.NodeID, bool) // GetValidators returns the set of all validators // known to this peer manager - GetValidators() set.Set[ids.NodeID] + GetValidators() set.Set[ids.NodeWeight] // ConnectedValidators returns the set of all validators // that are currently connected - ConnectedValidators() set.Set[ids.NodeID] + ConnectedValidators() set.Set[ids.NodeWeight] } type lockedPeers struct { @@ -112,14 +111,14 @@ func (p *lockedPeers) SampleValidator() (ids.NodeID, bool) { return p.peers.SampleValidator() } -func (p *lockedPeers) GetValidators() set.Set[ids.NodeID] { +func (p *lockedPeers) GetValidators() set.Set[ids.NodeWeight] { p.lock.RLock() defer p.lock.RUnlock() return p.peers.GetValidators() } -func (p *lockedPeers) ConnectedValidators() set.Set[ids.NodeID] { +func (p *lockedPeers) ConnectedValidators() set.Set[ids.NodeWeight] { p.lock.RLock() defer p.lock.RUnlock() @@ -272,14 +271,21 @@ func (p *peerData) SampleValidator() (ids.NodeID, bool) { return p.connectedValidators.Peek() } -func (p *peerData) GetValidators() set.Set[ids.NodeID] { - return set.Of(maps.Keys(p.validators)...) +func (p *peerData) GetValidators() set.Set[ids.NodeWeight] { + res := set.NewSet[ids.NodeWeight](len(p.validators)) + for k, v := range p.validators { + res.Add(ids.NodeWeight{Node: k, Weight: v}) + } + return res } -func (p *peerData) ConnectedValidators() set.Set[ids.NodeID] { +func (p *peerData) ConnectedValidators() set.Set[ids.NodeWeight] { // The set is copied to avoid future changes from being reflected in the // returned set. 
- copied := set.NewSet[ids.NodeID](len(p.connectedValidators)) - copied.Union(p.connectedValidators) + copied := set.NewSet[ids.NodeWeight](len(p.connectedValidators)) + for _, vdrID := range p.connectedValidators.List() { + weight := p.validators[vdrID] + copied.Add(ids.NodeWeight{Node: vdrID, Weight: weight}) + } return copied } diff --git a/snow/engine/common/tracker/peers_test.go b/snow/engine/common/tracker/peers_test.go index 1ed2daf6575b..ac577a399b48 100644 --- a/snow/engine/common/tracker/peers_test.go +++ b/snow/engine/common/tracker/peers_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/require" "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/utils/set" "github.com/ava-labs/avalanchego/version" ) @@ -40,3 +41,27 @@ func TestPeers(t *testing.T) { require.NoError(p.Disconnected(context.Background(), nodeID)) require.Zero(p.ConnectedWeight()) } + +func TestConnectedValidators(t *testing.T) { + require := require.New(t) + + nodeID1 := ids.GenerateTestNodeID() + nodeID2 := ids.GenerateTestNodeID() + + p := NewPeers() + + p.OnValidatorAdded(nodeID1, nil, ids.Empty, 5) + p.OnValidatorAdded(nodeID2, nil, ids.Empty, 6) + + require.NoError(p.Connected(context.Background(), nodeID1, version.CurrentApp)) + require.Equal(uint64(5), p.ConnectedWeight()) + + require.NoError(p.Connected(context.Background(), nodeID2, version.CurrentApp)) + require.Equal(uint64(11), p.ConnectedWeight()) + require.True(set.Of(ids.NodeWeight{Node: nodeID1, Weight: 5}, ids.NodeWeight{Node: nodeID2, Weight: 6}).Equals(p.GetValidators())) + require.True(set.Of(ids.NodeWeight{Node: nodeID1, Weight: 5}, ids.NodeWeight{Node: nodeID2, Weight: 6}).Equals(p.ConnectedValidators())) + + require.NoError(p.Disconnected(context.Background(), nodeID2)) + require.True(set.Of(ids.NodeWeight{Node: nodeID1, Weight: 5}, ids.NodeWeight{Node: nodeID2, Weight: 6}).Equals(p.GetValidators())) + require.True(set.Of(ids.NodeWeight{Node: nodeID1, Weight: 5}).Equals(p.ConnectedValidators())) +} diff --git a/snow/engine/snowman/engine_decorator.go b/snow/engine/snowman/engine_decorator.go new file mode 100644 index 000000000000..cd799347ff00 --- /dev/null +++ b/snow/engine/snowman/engine_decorator.go @@ -0,0 +1,59 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
+ +package snowman + +import ( + "context" + "time" + + "go.uber.org/zap" + + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/snow/engine/common" +) + +type EngineStragglerDetector struct { + Listener func(duration time.Duration) + Time func() time.Time +} + +func (ed *EngineStragglerDetector) AttachToEngine(e *Engine) common.Engine { + minConfRatio := float64(e.Params.AlphaConfidence) / float64(e.Params.K) + sd := newStragglerDetector(ed.Time, e.Config.Ctx.Log, minConfRatio, e.Consensus.LastAccepted, + e.Config.ConnectedValidators.ConnectedValidators, e.Config.ConnectedValidators.ConnectedPercent, + e.Consensus.Processing, e.acceptedFrontiers.LastAccepted) + de := &DecoratedEngine{Engine: e} + de.decorate("Chits", func(e *Engine) { + behindDuration := sd.CheckIfWeAreStragglingBehind() + if behindDuration > 0 { + e.Config.Ctx.Log.Info("We are behind the rest of the network", zap.Float64("seconds", behindDuration.Seconds())) + } + e.metrics.stragglingDuration.Set(float64(behindDuration)) + ed.Listener(behindDuration) + }) + + return de +} + +type DecoratedEngine struct { + decorations map[string]func(*Engine) + + *Engine +} + +func (de *DecoratedEngine) decorate(method string, f func(*Engine)) { + if de.decorations == nil { + de.decorations = map[string]func(*Engine){} + } + de.decorations[method] = f +} + +func (de *DecoratedEngine) Chits(ctx context.Context, nodeID ids.NodeID, requestID uint32, preferredID ids.ID, preferredIDAtHeight ids.ID, acceptedID ids.ID) error { + f, ok := de.decorations["Chits"] + if !ok { + panic("programming error: decorator for Chits not registered") + } + f(de.Engine) + return de.Engine.Chits(ctx, nodeID, requestID, preferredID, preferredIDAtHeight, acceptedID) +} diff --git a/snow/engine/snowman/engine_decorator_test.go b/snow/engine/snowman/engine_decorator_test.go new file mode 100644 index 000000000000..4f5dfd05cec0 --- /dev/null +++ b/snow/engine/snowman/engine_decorator_test.go @@ -0,0 +1,75 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. 
+ +package snowman + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/snow/consensus/snowman" + "github.com/ava-labs/avalanchego/snow/consensus/snowman/snowmantest" +) + +func TestEngineStragglerDetector(t *testing.T) { + require := require.New(t) + + fakeClock := make(chan time.Time, 1) + + conf := DefaultConfig(t) + peerID, _, sender, vm, engine := setup(t, conf) + + parent := snowmantest.BuildChild(snowmantest.Genesis) + require.NoError(conf.Consensus.Add(parent)) + + listenerShouldInvokeWith := []time.Duration{0, 0, time.Second * 2} + + esd := &EngineStragglerDetector{ + Listener: func(duration time.Duration) { + require.Equal(listenerShouldInvokeWith[0], duration) + listenerShouldInvokeWith = listenerShouldInvokeWith[1:] + }, + Time: func() time.Time { + select { + case now := <-fakeClock: + return now + default: + require.Fail("should have a time.Time in the channel") + return time.Time{} + } + }, + } + + decoratedEngine := esd.AttachToEngine(engine) + + vm.GetBlockF = func(_ context.Context, blkID ids.ID) (snowman.Block, error) { + switch blkID { + case snowmantest.GenesisID: + return snowmantest.Genesis, nil + default: + return nil, errUnknownBlock + } + } + + sender.SendGetF = func(_ context.Context, _ ids.NodeID, _ uint32, _ ids.ID) { + } + vm.ParseBlockF = func(_ context.Context, _ []byte) (snowman.Block, error) { + require.FailNow("should not be called") + return nil, nil + } + + now := time.Now() + fakeClock <- now + require.NoError(decoratedEngine.Chits(context.Background(), peerID, 0, parent.ID(), parent.ID(), parent.ID())) + now = now.Add(time.Second * 2) + fakeClock <- now + require.NoError(decoratedEngine.Chits(context.Background(), peerID, 0, parent.ID(), parent.ID(), parent.ID())) + now = now.Add(time.Second * 2) + fakeClock <- now + require.NoError(decoratedEngine.Chits(context.Background(), peerID, 0, parent.ID(), parent.ID(), parent.ID())) + require.Empty(listenerShouldInvokeWith) +} diff --git a/snow/engine/snowman/metrics.go b/snow/engine/snowman/metrics.go index 922b18200d47..68856ba1054b 100644 --- a/snow/engine/snowman/metrics.go +++ b/snow/engine/snowman/metrics.go @@ -23,6 +23,7 @@ type metrics struct { numBlocked prometheus.Gauge numBlockers prometheus.Gauge numNonVerifieds prometheus.Gauge + stragglingDuration prometheus.Gauge numBuilt prometheus.Counter numBuildsFailed prometheus.Counter numUselessPutBytes prometheus.Counter @@ -41,6 +42,10 @@ type metrics struct { func newMetrics(reg prometheus.Registerer) (*metrics, error) { errs := wrappers.Errs{} m := &metrics{ + stragglingDuration: prometheus.NewGauge(prometheus.GaugeOpts{ + Name: "straggling_duration", + Help: "For how long we have been straggling behind the rest, in nano-seconds.", + }), bootstrapFinished: prometheus.NewGauge(prometheus.GaugeOpts{ Name: "bootstrap_finished", Help: "Whether or not bootstrap process has completed. 1 is success, 0 is fail or ongoing.", @@ -128,6 +133,7 @@ func newMetrics(reg prometheus.Registerer) (*metrics, error) { m.issued.WithLabelValues(unknownSource) errs.Add( + reg.Register(m.stragglingDuration), reg.Register(m.bootstrapFinished), reg.Register(m.numRequests), reg.Register(m.numBlocked), diff --git a/snow/engine/snowman/straggler_detect.go b/snow/engine/snowman/straggler_detect.go new file mode 100644 index 000000000000..736aa9f4100b --- /dev/null +++ b/snow/engine/snowman/straggler_detect.go @@ -0,0 +1,307 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. 
All rights reserved.
+// See the file LICENSE for licensing terms.
+
+package snowman
+
+import (
+	"fmt"
+	"time"
+
+	"go.uber.org/zap"
+	"golang.org/x/exp/maps"
+
+	"github.com/ava-labs/avalanchego/ids"
+	"github.com/ava-labs/avalanchego/utils/logging"
+	"github.com/ava-labs/avalanchego/utils/set"
+
+	safemath "github.com/ava-labs/avalanchego/utils/math"
+)
+
+const (
+	minStragglerCheckInterval           = time.Second
+	stakeThresholdForStragglerSuspicion = 0.75
+)
+
+type stragglerDetectorConfig struct {
+	// getTime returns the current time
+	getTime func() time.Time
+
+	// minStragglerCheckInterval determines how frequently we are allowed to check if we are stragglers.
+	minStragglerCheckInterval time.Duration
+
+	// log logs events
+	log logging.Logger
+
+	// minConfirmationThreshold is the minimum percentage of connected stake below which we do not check if we are stragglers.
+	minConfirmationThreshold float64
+
+	// connectedPercent returns the percentage of connected stake.
+	connectedPercent func() float64
+
+	// connectedValidators returns a set of tuples of NodeID and corresponding weight.
+	connectedValidators func() set.Set[ids.NodeWeight]
+
+	// lastAcceptedByNodeID returns the ID of the last block a node has reported as accepted, or false if it is unknown.
+	lastAcceptedByNodeID func(id ids.NodeID) (ids.ID, bool)
+
+	// processing returns whether this block ID is known and its descendants have not yet been accepted by consensus.
+	// This means that when the last accepted block is given as input, true is returned, as by definition
+	// its descendants have not been accepted by consensus, but this block is known.
+	// For any block ID belonging to an ancestor of the last accepted block, false is returned,
+	// as the last accepted block has been accepted by consensus.
+	processing func(id ids.ID) bool
+
+	// lastAccepted returns the last accepted block of this node.
+	lastAccepted func() ids.ID
+
+	// getSnapshot returns a snapshot of the network's nodes and their last accepted blocks,
+	// or false if it fails for some reason.
+	getSnapshot func() (snapshot, bool)
+
+	// haveWeFailedCatchingUp returns whether we have not replicated enough blocks from the given snapshot.
+	haveWeFailedCatchingUp func(snapshot) bool
+}
+
+type stragglerDetector struct {
+	stragglerDetectorConfig
+
+	// continuousStragglingPeriod defines the time we have been straggling continuously.
+	continuousStragglingPeriod time.Duration
+
+	// previousStragglerCheckTime is the last time we checked whether
+	// our block height is behind the rest of the network
+	previousStragglerCheckTime time.Time
+
+	// prevSnapshot is the snapshot from a past iteration.
+ prevSnapshot snapshot +} + +func newStragglerDetector( + getTime func() time.Time, + log logging.Logger, + minConfirmationThreshold float64, + lastAccepted func() (ids.ID, uint64), + connectedValidators func() set.Set[ids.NodeWeight], + connectedPercent func() float64, + processing func(id ids.ID) bool, + lastAcceptedByNodeID func(ids.NodeID) (ids.ID, bool), +) *stragglerDetector { + sd := &stragglerDetector{ + stragglerDetectorConfig: stragglerDetectorConfig{ + lastAccepted: dropHeight(lastAccepted), + processing: processing, + minStragglerCheckInterval: minStragglerCheckInterval, + log: log, + connectedValidators: connectedValidators, + connectedPercent: connectedPercent, + minConfirmationThreshold: minConfirmationThreshold, + lastAcceptedByNodeID: lastAcceptedByNodeID, + getTime: getTime, + }, + } + + sd.getSnapshot = sd.getNetworkSnapshot + sd.haveWeFailedCatchingUp = sd.failedCatchingUp + + return sd +} + +// CheckIfWeAreStragglingBehind returns for how long our ledger is behind the rest +// of the nodes in the network. If we are not behind, zero is returned. +func (sd *stragglerDetector) CheckIfWeAreStragglingBehind() time.Duration { + now := sd.getTime() + if sd.previousStragglerCheckTime.IsZero() { + sd.previousStragglerCheckTime = now + return 0 + } + + // Don't check too often, only once in every minStragglerCheckInterval + if sd.previousStragglerCheckTime.Add(sd.minStragglerCheckInterval).After(now) { + return 0 + } + + defer func() { + sd.previousStragglerCheckTime = now + }() + + if sd.prevSnapshot.isEmpty() { + snapshot, ok := sd.getSnapshot() + if !ok { + sd.log.Trace("No node snapshot obtained") + sd.continuousStragglingPeriod = 0 + } + sd.prevSnapshot = snapshot + } else { + if sd.haveWeFailedCatchingUp(sd.prevSnapshot) { + timeSinceLastCheck := now.Sub(sd.previousStragglerCheckTime) + sd.continuousStragglingPeriod += timeSinceLastCheck + } else { + sd.continuousStragglingPeriod = 0 + } + sd.prevSnapshot = snapshot{} + } + + return sd.continuousStragglingPeriod +} + +func (sd *stragglerDetector) failedCatchingUp(s snapshot) bool { + totalValidatorWeight, nodeWeights2Blocks := s.totalValidatorWeight, s.nodeWeights2Blocks + + var processingWeight uint64 + for nw, lastAccepted := range nodeWeights2Blocks { + if sd.processing(lastAccepted) { + newProcessingWeight, err := safemath.Add(processingWeight, nw.Weight) + if err != nil { + sd.log.Error("Cumulative weight overflow", zap.Uint64("cumulative", processingWeight), zap.Uint64("added", nw.Weight)) + return false + } + processingWeight = newProcessingWeight + } + } + + sd.log.Trace("Counted total weight that accepted blocks we're still processing", zap.Uint64("weight", processingWeight)) + + ratio := float64(processingWeight) / float64(totalValidatorWeight) + + if ratio > stakeThresholdForStragglerSuspicion { + sd.log.Trace("We are straggling behind", zap.Float64("ratio", ratio)) + return true + } + + sd.log.Trace("Nodes ahead of us:", zap.Float64("ratio", ratio)) + + return false +} + +func (sd *stragglerDetector) getNetworkSnapshot() (snapshot, bool) { + nodeWeight2lastAccepted, totalValidatorWeight, _ := sd.getNetworkInfo() + if len(nodeWeight2lastAccepted) == 0 { + return snapshot{}, false + } + + ourLastAcceptedBlock := sd.lastAccepted() + + prevLastAcceptedCount := len(nodeWeight2lastAccepted) + for k, v := range nodeWeight2lastAccepted { + if ourLastAcceptedBlock.Compare(v) == 0 { + delete(nodeWeight2lastAccepted, k) + } + } + newLastAcceptedCount := len(nodeWeight2lastAccepted) + + sd.log.Trace("Excluding nodes with 
our own height", zap.Int("prev", prevLastAcceptedCount), zap.Int("new", newLastAcceptedCount)) + + // Ensure we have collected last accepted blocks that are not our own last accepted block + // for at least 80% stake of the total weight we are connected to. + + totalWeightWeKnowItsLastAcceptedBlock, err := nodeWeight2lastAccepted.totalWeight() + if err != nil { + sd.log.Error("Failed computing total weight", zap.Error(err)) + return snapshot{}, false + } + + ratio := float64(totalWeightWeKnowItsLastAcceptedBlock) / float64(totalValidatorWeight) + + if ratio < 0.8 { + sd.log.Trace("Most stake we're connected to has the same height as we do", + zap.Float64("ratio", ratio)) + return snapshot{}, false + } + + snap := snapshot{ + nodeWeights2Blocks: nodeWeight2lastAccepted, + totalValidatorWeight: totalValidatorWeight, + } + + if sd.haveWeFailedCatchingUp(snap) { + return snap, true + } + + return snapshot{}, false +} + +func (sd *stragglerDetector) getNetworkInfo() (nodeWeights2Blocks, uint64, uint64) { + ratio := sd.connectedPercent() + if ratio < sd.minConfirmationThreshold { + // We don't know for sure whether we're behind or not. + // Even if we're behind, it's pointless to act before we have established + // connectivity with enough validators. + sd.log.Verbo("not enough connected stake to determine network info", zap.Float64("ratio", ratio)) + return nil, 0, 0 + } + + validators := nodeWeights(sd.connectedValidators().List()) + + nodeWeight2lastAccepted := make(nodeWeights2Blocks, len(validators)) + + for _, vdr := range validators { + lastAccepted, ok := sd.lastAcceptedByNodeID(vdr.Node) + if !ok { + continue + } + nodeWeight2lastAccepted[vdr] = lastAccepted + } + + totalValidatorWeight, err := validators.totalWeight() + if err != nil { + sd.log.Error("Failed computing total weight", zap.Error(err)) + return nil, 0, 0 + } + + totalWeightWeKnowItsLastAcceptedBlock, err := nodeWeight2lastAccepted.totalWeight() + if err != nil { + sd.log.Error("Failed computing total weight", zap.Error(err)) + return nil, 0, 0 + } + + if totalValidatorWeight == 0 { + sd.log.Trace("Connected to zero weight") + return nil, 0, 0 + } + + ratio = float64(totalWeightWeKnowItsLastAcceptedBlock) / float64(totalValidatorWeight) + + // Ensure we have collected last accepted blocks for at least 80% stake of the total weight we are connected to. 
+ if ratio < 0.8 { + sd.log.Trace("Not collected enough information about last accepted blocks for the validators we are connected to", + zap.Float64("ratio", ratio)) + return nil, 0, 0 + } + return nodeWeight2lastAccepted, totalValidatorWeight, totalWeightWeKnowItsLastAcceptedBlock +} + +type snapshot struct { + totalValidatorWeight uint64 + nodeWeights2Blocks nodeWeights2Blocks +} + +func (s snapshot) isEmpty() bool { + return s.totalValidatorWeight == 0 || len(s.nodeWeights2Blocks) == 0 +} + +type nodeWeights2Blocks map[ids.NodeWeight]ids.ID + +func (nw2b nodeWeights2Blocks) totalWeight() (uint64, error) { + return nodeWeights(maps.Keys(nw2b)).totalWeight() +} + +func dropHeight(f func() (ids.ID, uint64)) func() ids.ID { + return func() ids.ID { + id, _ := f() + return id + } +} + +type nodeWeights []ids.NodeWeight + +func (nws nodeWeights) totalWeight() (uint64, error) { + var weight uint64 + for _, nw := range nws { + newWeight, err := safemath.Add(weight, nw.Weight) + if err != nil { + return 0, fmt.Errorf("cumulative weight: %d, tried to add %d: %w", weight, nw.Weight, err) + } + weight = newWeight + } + return weight, nil +} diff --git a/snow/engine/snowman/straggler_detect_test.go b/snow/engine/snowman/straggler_detect_test.go new file mode 100644 index 000000000000..5d68711b25ff --- /dev/null +++ b/snow/engine/snowman/straggler_detect_test.go @@ -0,0 +1,376 @@ +// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package snowman + +import ( + "math" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/utils/logging" + "github.com/ava-labs/avalanchego/utils/set" + + safemath "github.com/ava-labs/avalanchego/utils/math" +) + +func TestNodeWeights(t *testing.T) { + nws := nodeWeights{ + {Weight: 100}, + {Weight: 50}, + } + + total, err := nws.totalWeight() + require.NoError(t, err) + require.Equal(t, uint64(150), total) +} + +func TestNodeWeightsOverflow(t *testing.T) { + nws := nodeWeights{ + {Weight: math.MaxUint64 - 100}, + {Weight: 110}, + } + + total, err := nws.totalWeight() + require.ErrorIs(t, err, safemath.ErrOverflow) + require.Equal(t, uint64(0), total) +} + +func TestNodeWeights2Blocks(t *testing.T) { + nw2b := nodeWeights2Blocks{ + ids.NodeWeight{Weight: 5}: ids.Empty, + ids.NodeWeight{Weight: 10}: ids.Empty, + } + + total, err := nw2b.totalWeight() + require.NoError(t, err) + require.Equal(t, uint64(15), total) +} + +func TestGetNetworkSnapshot(t *testing.T) { + n1, err := ids.NodeIDFromString("NodeID-N5gc5soT3Gpr98NKpqvQQG2SgGrVPL64w") + require.NoError(t, err) + + n2, err := ids.NodeIDFromString("NodeID-NpagUxt6KQiwPch9Sd4osv8kD1TZnkjdk") + require.NoError(t, err) + + connectedValidators := func(s []ids.NodeWeight) func() set.Set[ids.NodeWeight] { + return func() set.Set[ids.NodeWeight] { + var set set.Set[ids.NodeWeight] + for _, nw := range s { + set.Add(nw) + } + return set + } + } + + for _, testCase := range []struct { + description string + lastAccepted ids.ID + lastAcceptedFromNodes map[ids.NodeID]ids.ID + processing map[ids.ID]struct{} + connectedValidators func() set.Set[ids.NodeWeight] + connectedPercent float64 + expectedSnapshot snapshot + expectedOK bool + expectedLogged string + }{ + { + description: "not enough connected validators", + connectedValidators: connectedValidators([]ids.NodeWeight{}), + expectedLogged: "not enough connected stake to determine network info", + }, + { + description: "connected to 
zero weight", + connectedPercent: 1.0, + connectedValidators: connectedValidators([]ids.NodeWeight{}), + expectedLogged: "Connected to zero weight", + }, + { + description: "not enough info", + connectedPercent: 1.0, + connectedValidators: connectedValidators([]ids.NodeWeight{{Weight: 1, Node: n1}, {Weight: 999999, Node: n2}}), + lastAcceptedFromNodes: map[ids.NodeID]ids.ID{ + n1: {0x1}, + }, + expectedLogged: "Not collected enough information about last accepted blocks", + }, + { + description: "we're in sync", + connectedPercent: 1.0, + connectedValidators: connectedValidators([]ids.NodeWeight{{Weight: 999999, Node: n1}}), + lastAcceptedFromNodes: map[ids.NodeID]ids.ID{ + n1: {0x1}, + }, + lastAccepted: ids.ID{0x1}, + expectedLogged: "Most stake we're connected to has the same height as we do", + }, + { + description: "we're behind", + connectedPercent: 1.0, + connectedValidators: connectedValidators([]ids.NodeWeight{{Weight: 999999, Node: n1}}), + lastAcceptedFromNodes: map[ids.NodeID]ids.ID{ + n1: {0x1}, + }, + processing: map[ids.ID]struct{}{{0x1}: {}}, + lastAccepted: ids.ID{0x0}, + expectedSnapshot: snapshot{totalValidatorWeight: 999999, nodeWeights2Blocks: nodeWeights2Blocks{ + ids.NodeWeight{Node: n1, Weight: 999999}: {0x1}, + }}, + expectedOK: true, + }, + } { + t.Run(testCase.description, func(t *testing.T) { + var buff logBuffer + log := logging.NewLogger("", logging.NewWrappedCore(logging.Verbo, &buff, logging.Plain.ConsoleEncoder())) + + sd := newStragglerDetector(nil, log, 0.75, + func() (ids.ID, uint64) { + return testCase.lastAccepted, 0 + }, + testCase.connectedValidators, func() float64 { return testCase.connectedPercent }, + func(id ids.ID) bool { + _, ok := testCase.processing[id] + return ok + }, + func(vdr ids.NodeID) (ids.ID, bool) { + id, ok := testCase.lastAcceptedFromNodes[vdr] + return id, ok + }) + + snapshot, ok := sd.getNetworkSnapshot() + require.Equal(t, testCase.expectedSnapshot, snapshot) + require.Equal(t, testCase.expectedOK, ok) + require.Contains(t, buff.String(), testCase.expectedLogged) + }) + } +} + +func TestFailedCatchingUp(t *testing.T) { + n1, err := ids.NodeIDFromString("NodeID-N5gc5soT3Gpr98NKpqvQQG2SgGrVPL64w") + require.NoError(t, err) + + n2, err := ids.NodeIDFromString("NodeID-NpagUxt6KQiwPch9Sd4osv8kD1TZnkjdk") + require.NoError(t, err) + + for _, testCase := range []struct { + description string + lastAccepted ids.ID + lastAcceptedFromNodes map[ids.NodeID]ids.ID + processing map[ids.ID]struct{} + connectedValidators []ids.NodeWeight + connectedPercent float64 + input snapshot + expected bool + expectedLogged string + }{ + { + description: "stake overflow", + input: snapshot{ + nodeWeights2Blocks: nodeWeights2Blocks{ + ids.NodeWeight{Node: n1, Weight: math.MaxUint64 - 10}: ids.ID{0x1}, + ids.NodeWeight{Node: n2, Weight: 11}: ids.ID{0x2}, + }, + }, + processing: map[ids.ID]struct{}{ + {0x1}: {}, + {0x2}: {}, + }, + expectedLogged: "Cumulative weight overflow", + }, + { + description: "Straggling behind stake minority", + input: snapshot{ + totalValidatorWeight: 100, nodeWeights2Blocks: nodeWeights2Blocks{ + ids.NodeWeight{Node: n1, Weight: 25}: ids.ID{0x1}, + ids.NodeWeight{Node: n2, Weight: 50}: ids.ID{0x2}, + }, + }, + processing: map[ids.ID]struct{}{ + {0x1}: {}, + {0x2}: {}, + }, + expectedLogged: "Nodes ahead of us", + }, + { + description: "Straggling behind stake majority", + input: snapshot{ + totalValidatorWeight: 100, nodeWeights2Blocks: nodeWeights2Blocks{ + ids.NodeWeight{Node: n1, Weight: 26}: ids.ID{0x1}, + 
ids.NodeWeight{Node: n2, Weight: 50}: ids.ID{0x2}, + }, + }, + processing: map[ids.ID]struct{}{ + {0x1}: {}, + {0x2}: {}, + }, + expectedLogged: "We are straggling behind", + expected: true, + }, + { + description: "In sync with the majority", + input: snapshot{ + totalValidatorWeight: 100, nodeWeights2Blocks: nodeWeights2Blocks{ + ids.NodeWeight{Node: n1, Weight: 75}: ids.ID{0x1}, + ids.NodeWeight{Node: n2, Weight: 25}: ids.ID{0x2}, + }, + }, + processing: map[ids.ID]struct{}{ + {0x2}: {}, + }, + expectedLogged: "Nodes ahead of us", + }, + } { + t.Run(testCase.description, func(t *testing.T) { + var buff logBuffer + log := logging.NewLogger("", logging.NewWrappedCore(logging.Verbo, &buff, logging.Plain.ConsoleEncoder())) + + sd := newStragglerDetector(nil, log, 0.75, + func() (ids.ID, uint64) { + return testCase.lastAccepted, 0 + }, + func() set.Set[ids.NodeWeight] { + var set set.Set[ids.NodeWeight] + for _, nw := range testCase.connectedValidators { + set.Add(nw) + } + return set + }, func() float64 { return testCase.connectedPercent }, + func(id ids.ID) bool { + _, ok := testCase.processing[id] + return ok + }, + func(vdr ids.NodeID) (ids.ID, bool) { + id, ok := testCase.lastAcceptedFromNodes[vdr] + return id, ok + }) + + require.Equal(t, testCase.expected, sd.failedCatchingUp(testCase.input)) + require.Contains(t, buff.String(), testCase.expectedLogged) + }) + } +} + +func TestCheckIfWeAreStragglingBehind(t *testing.T) { + fakeClock := make(chan time.Time, 1) + + snapshots := make(chan snapshot, 1) + assertNoSnapshotsRemain := func() { + select { + case <-snapshots: + require.Fail(t, "Should not have any snapshots in standby") + default: + } + } + nonEmptySnap := snapshot{ + totalValidatorWeight: 100, + nodeWeights2Blocks: nodeWeights2Blocks{ + ids.NodeWeight{Weight: 100}: ids.Empty, + }, + } + + var haveWeFailedCatchingUpReturns bool + + var buff logBuffer + log := logging.NewLogger("", logging.NewWrappedCore(logging.Verbo, &buff, logging.Plain.ConsoleEncoder())) + + sd := stragglerDetector{ + stragglerDetectorConfig: stragglerDetectorConfig{ + minStragglerCheckInterval: time.Second, + getTime: func() time.Time { + now := <-fakeClock + return now + }, + log: log, + getSnapshot: func() (snapshot, bool) { + s := <-snapshots + return s, !s.isEmpty() + }, + haveWeFailedCatchingUp: func(_ snapshot) bool { + return haveWeFailedCatchingUpReturns + }, + }, + } + + fakeTime := time.Now() + + for _, testCase := range []struct { + description string + timeAdvanced time.Duration + evalExtraAssertions func() + expectedStragglingTime time.Duration + snapshotsRead []snapshot + haveWeFailedCatchingUpReturns bool + }{ + { + description: "First invocation only sets the time", + evalExtraAssertions: func() {}, + }, + { + description: "Should not check yet, as it is not time yet", + timeAdvanced: time.Millisecond * 500, + evalExtraAssertions: func() {}, + }, + { + description: "Advance time some more, so now we should check", + timeAdvanced: time.Millisecond * 501, + snapshotsRead: []snapshot{{}}, + evalExtraAssertions: func() { + require.Contains(t, buff.String(), "No node snapshot obtained") + }, + }, + { + description: "Advance time some more to the first check where the snapshot isn't empty", + timeAdvanced: time.Second * 2, + snapshotsRead: []snapshot{nonEmptySnap}, + evalExtraAssertions: func() { + require.Empty(t, buff.String()) + }, + }, + { + description: "The next check returns we have failed catching up.", + timeAdvanced: time.Second * 2, + expectedStragglingTime: time.Second * 2, + 
haveWeFailedCatchingUpReturns: true, + evalExtraAssertions: func() { + require.Empty(t, sd.prevSnapshot) + }, + }, + { + description: "The third snapshot is due to a fresh check", + timeAdvanced: time.Second * 2, + snapshotsRead: []snapshot{nonEmptySnap}, + // We carry over the total straggling time from previous testCase to this check, + // as we need the next check to nullify it. + expectedStragglingTime: time.Second * 2, + evalExtraAssertions: func() {}, + }, + { + description: "The fourth check returns we have succeeded in catching up", + timeAdvanced: time.Second * 2, + evalExtraAssertions: func() {}, + }, + } { + t.Run(testCase.description, func(t *testing.T) { + fakeTime = fakeTime.Add(testCase.timeAdvanced) + fakeClock <- fakeTime + + // Load the snapshot expected to be retrieved in this testCase, if applicable. + if len(testCase.snapshotsRead) > 0 { + snapshots <- testCase.snapshotsRead[0] + } + + haveWeFailedCatchingUpReturns = testCase.haveWeFailedCatchingUpReturns + require.Equal(t, testCase.expectedStragglingTime, sd.CheckIfWeAreStragglingBehind()) + testCase.evalExtraAssertions() + + // Cleanup the log buffer, and make sure no snapshots remain for next testCase. + buff.Reset() + assertNoSnapshotsRemain() + haveWeFailedCatchingUpReturns = false + }) + } +} diff --git a/snow/networking/handler/health.go b/snow/networking/handler/health.go index 0dbcb844fb95..fbbc8113e2d1 100644 --- a/snow/networking/handler/health.go +++ b/snow/networking/handler/health.go @@ -66,5 +66,13 @@ func (h *handler) getDisconnectedValidators() set.Set[ids.NodeID] { connectedVdrs := h.peerTracker.ConnectedValidators() // vdrs - connectedVdrs is equal to the disconnectedVdrs vdrs.Difference(connectedVdrs) - return vdrs + return trimWeights(vdrs) +} + +func trimWeights(weights set.Set[ids.NodeWeight]) set.Set[ids.NodeID] { + var res set.Set[ids.NodeID] + for _, nw := range weights.List() { + res.Add(nw.Node) + } + return res }
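
The patch wires a no-op Listener into chains/manager.go; how the reported straggling duration eventually re-triggers bootstrapping is left to a follow-up change. As a rough, self-contained sketch of one way a consumer could react to the detector's output, the Go program below wraps a threshold check around a callback. The newStragglerListener helper, the 30-second threshold, and the onStraggling callback are illustrative assumptions and are not part of this change.

package main

import (
	"fmt"
	"time"
)

// newStragglerListener returns a callback with the same signature as
// EngineStragglerDetector.Listener. The detector reports the cumulative
// duration the node has been straggling (zero once it is caught up), so the
// callback only reacts when that duration crosses the given threshold.
func newStragglerListener(threshold time.Duration, onStraggling func(time.Duration)) func(time.Duration) {
	return func(behind time.Duration) {
		if behind >= threshold {
			onStraggling(behind)
		}
	}
}

func main() {
	listener := newStragglerListener(30*time.Second, func(behind time.Duration) {
		// A later change could re-trigger bootstrapping here instead of logging.
		fmt.Printf("straggling behind the network for %s\n", behind)
	})

	// In chains/manager.go this would replace the no-op listener:
	//   ed := smeng.EngineStragglerDetector{Listener: listener, Time: time.Now}
	listener(10 * time.Second) // below the threshold: nothing happens
	listener(45 * time.Second) // above the threshold: the callback fires
}

Passing the cumulative straggling duration to the listener, rather than a boolean, lets each consumer decide how long to tolerate being behind before acting.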