Skip to content

Commit

Permalink
More checks.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jan 16, 2024
1 parent 53b8d4a commit 870033b
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 5 deletions.
4 changes: 2 additions & 2 deletions include/xgboost/tree_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -398,8 +398,8 @@ class RegTree : public Model {
if (!func(nidx)) {
return;
}
auto left = self[nidx].LeftChild();
auto right = self[nidx].RightChild();
auto left = self.LeftChild(nidx);
auto right = self.RightChild(nidx);
if (left != RegTree::kInvalidNodeId) {
nodes.push(left);
}
Expand Down
3 changes: 2 additions & 1 deletion src/tree/multi_target_tree_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,9 @@ void MultiTargetTree::Expand(bst_node_t nidx, bst_feature_t split_idx, float spl
split_index_.resize(n);
split_index_[nidx] = split_idx;

split_conds_.resize(n);
split_conds_.resize(n, std::numeric_limits<float>::quiet_NaN());
split_conds_[nidx] = split_cond;

default_left_.resize(n);
default_left_[nidx] = static_cast<std::uint8_t>(default_left);

Expand Down
28 changes: 26 additions & 2 deletions tests/cpp/tree/test_tree_stat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,29 +96,53 @@ class TestSplitWithEta : public ::testing::Test {
updater->Configure({});

auto grad = GenerateRandomGradients(ctx, Xy->Info().num_row_, n_targets);
CHECK_EQ(grad.Shape(1), n_targets);
tree::TrainParam param;
param.Init(Args{{"learning_rate", std::to_string(eta)}});
HostDeviceVector<bst_node_t> position;

updater->Update(&param, &grad, Xy.get(), common::Span{&position, 1}, trees);
CHECK_EQ(tree->NumTargets(), n_targets);
if (n_targets > 1) {
CHECK(tree->IsMultiTarget());
}
return tree;
};

auto eta_ratio = 8.0f;
auto p_tree0 = gen_tree(0.1f);
auto p_tree1 = gen_tree(0.8f);
auto p_tree1 = gen_tree(0.1f * eta_ratio);
// Just to make sure we are not testing a stump.
CHECK_GE(p_tree0->NumExtraNodes(), 32);

bst_node_t n_nodes{0};
p_tree0->WalkTree([&](bst_node_t nidx) {
if (p_tree0->IsLeaf(nidx)) {
CHECK(p_tree1->IsLeaf(nidx));
CHECK_EQ(p_tree0->SplitCond(nidx) * 8.0f, p_tree1->SplitCond(nidx));
if (p_tree0->IsMultiTarget()) {
CHECK(p_tree1->IsMultiTarget());
auto leaf_0 = p_tree0->GetMultiTargetTree()->LeafValue(nidx);
auto leaf_1 = p_tree1->GetMultiTargetTree()->LeafValue(nidx);
CHECK_EQ(leaf_0.Size(), leaf_1.Size());
for (std::size_t i = 0; i < leaf_0.Size(); ++i) {
CHECK_EQ(leaf_0(i) * eta_ratio, leaf_1(i));
}
CHECK(std::isnan(p_tree0->SplitCond(nidx)));
CHECK(std::isnan(p_tree1->SplitCond(nidx)));
} else {
// NON-mt tree reuses split cond for leaf value.
auto leaf_0 = p_tree0->SplitCond(nidx);
auto leaf_1 = p_tree1->SplitCond(nidx);
CHECK_EQ(leaf_0 * eta_ratio, leaf_1);
}
} else {
CHECK(!p_tree1->IsLeaf(nidx));
CHECK_EQ(p_tree0->SplitCond(nidx), p_tree1->SplitCond(nidx));
}
n_nodes++;
return true;
});
ASSERT_EQ(n_nodes, p_tree0->NumExtraNodes() + 1);
}
};

Expand Down

0 comments on commit 870033b

Please sign in to comment.