diff --git a/internal/tree/treeconfig.go b/internal/tree/treeconfig.go index b458ff5c..616796e2 100644 --- a/internal/tree/treeconfig.go +++ b/internal/tree/treeconfig.go @@ -128,21 +128,18 @@ func (t *Tree) PeersOf(nodeID hotstuff.ID) []hotstuff.ID { return t.ChildrenOf(parent) } -// SubTree returns all the nodes of its subtree. +// SubTree returns all subtree nodes of this tree's replica. func (t *Tree) SubTree() []hotstuff.ID { - nodeID := t.id - subTreeNodes := make([]hotstuff.ID, 0) - children := t.ChildrenOfNode(nodeID) - queue := make([]hotstuff.ID, 0) - queue = append(queue, children...) - subTreeNodes = append(subTreeNodes, children...) + children := t.ChildrenOf(t.id) if len(children) == 0 { - return subTreeNodes + return nil } + subTreeNodes := slices.Clone(children) + queue := slices.Clone(children) for len(queue) > 0 { child := queue[0] queue = queue[1:] - children := t.ChildrenOfNode(child) + children := t.ChildrenOf(child) subTreeNodes = append(subTreeNodes, children...) queue = append(queue, children...) } diff --git a/internal/tree/treeconfig_test.go b/internal/tree/treeconfig_test.go index 7a949535..86aff96a 100644 --- a/internal/tree/treeconfig_test.go +++ b/internal/tree/treeconfig_test.go @@ -93,10 +93,10 @@ func TestTreeAPIWithInitializeWithPIDs(t *testing.T) { if tree.TreeHeight() != test.height { t.Errorf("Expected height %d, got %d", test.height, tree.TreeHeight()) } - gotChildren := tree.ChildrenOf() + gotChildren := tree.NodeChildren() sort.Slice(gotChildren, func(i, j int) bool { return gotChildren[i] < gotChildren[j] }) if len(gotChildren) != len(test.children) || !slices.Equal(gotChildren, test.children) { - t.Errorf("Expected %v, got %v", test.children, tree.ChildrenOf()) + t.Errorf("Expected %v, got %v", test.children, tree.NodeChildren()) } subTree := tree.SubTree() sort.Slice(subTree, func(i, j int) bool { return subTree[i] < subTree[j] }) @@ -112,8 +112,8 @@ func TestTreeAPIWithInitializeWithPIDs(t *testing.T) { if tree.IsRoot(test.id) != test.isRoot { t.Errorf("Expected %t, got %t", test.isRoot, tree.IsRoot(test.id)) } - if tree.GetHeight() != test.replicaHeight { - t.Errorf("Expected %d, got %d", test.replicaHeight, tree.GetHeight()) + if tree.NodeHeight() != test.replicaHeight { + t.Errorf("Expected %d, got %d", test.replicaHeight, tree.NodeHeight()) } gotPeers := tree.PeersOf(test.id) sort.Slice(gotPeers, func(i, j int) bool { return gotPeers[i] < gotPeers[j] }) @@ -130,7 +130,7 @@ func benchmarkGetChildren(size int, bf int, b *testing.B) { } tree := CreateTree(1, bf, ids) for i := 0; i < b.N; i++ { - tree.ChildrenOf() + tree.NodeChildren() } }