Skip to content

Commit

Permalink
Bug in oob error rate vs number of trees result
Browse files Browse the repository at this point in the history
  • Loading branch information
fradav committed Aug 27, 2020
1 parent 1c0abae commit e45d3eb
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 3 deletions.
4 changes: 4 additions & 0 deletions src/ForestOnline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ void ForestOnline::init(std::string dependent_variable_name, MemoryMode memory_m
if (!prediction_mode && order_snps) {
data->orderSnpLevels(dependent_variable_name, (importance_mode == IMP_GINI_CORRECTED));
}

tree_order = std::vector<size_t>(num_trees);

}

void ForestOnline::run(bool verbose, bool compute_oob_error) {
Expand Down Expand Up @@ -602,6 +605,7 @@ void ForestOnline::growTreesInThread(uint thread_idx, std::vector<double>* varia
trees[i]->predict(predict_data,false);
predictInternal(i);
mutex.lock();
tree_order[progress] = i;
++progress;
if (verbose_out) {
#ifdef PYTHON_OUTPUT
Expand Down
2 changes: 1 addition & 1 deletion src/ForestOnline.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ class ForestOnline {
PredictionType prediction_type;
uint num_random_splits;
uint max_depth;

std::vector<size_t> tree_order;
// MAXSTAT splitrule
double alpha;
double minprop;
Expand Down
6 changes: 4 additions & 2 deletions src/ForestOnlineClassification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ void ForestOnlineClassification::calculateAfterGrow(size_t tree_idx, bool oob) {
mutex_post.lock();
++class_counts[sampleID][res];
mutex_post.unlock();
if (!class_counts[sample_idx].empty())
if (!class_counts[sampleID].empty())
to_add += (mostFrequentValue(class_counts[sampleID], random_number_generator) == data->get(sampleID,dependent_varID)) ? 0.0 : 1.0;
}
predictions[2][0][tree_idx] += to_add/static_cast<double>(numOOB);
Expand Down Expand Up @@ -237,7 +237,9 @@ void ForestOnlineClassification::computePredictionErrorInternal()
for(auto sample_idx = 0; sample_idx < predict_data->getNumRows(); sample_idx++) {
predictions[1][0][sample_idx] = mostFrequentValue(class_count[sample_idx], random_number_generator);
}

std::vector<double> sort_oob_trees(num_trees);
for(auto i = 0; i < num_trees; i++) sort_oob_trees[i] = predictions[2][0][tree_order[i]];
predictions[2][0] = sort_oob_trees;
}

// #nocov start
Expand Down

0 comments on commit e45d3eb

Please sign in to comment.