Skip to content

Commit

Permalink
TreeEnsemble speed up (microsoft#17449)
Browse files Browse the repository at this point in the history
### Description
This PR proposes a change that should speed up inference for the
TreeEnsemble* kernels. Previously, when traversing a decision tree, the
`TreeNodeElement` pointer would be incremented or decremented to the
appropriate child node - I assume this was because the
`truenode_inc_or_first_weight` and `falsenode_inc_or_n_weights` member
were overloaded for two purposes.

In this PR, we now assign the true branch pointer. We also initialise
`nodes_` in a pre-order traversal which means that the false branch's
position can be resolved statically and does not need to be stored.

I observe the following speed ups. The benchmarks used are derived from
those in https://github.com/siboehm/lleaves/tree/master/benchmarks and
the baseline is the main branch.

NYC Dataset
--------------
| Number of threads | Baseline | Pointer assignment | Pre-ordered
initialisation | Pointer assignment % improvement | Pre-ordered
initialisation % improvement |

|--------------------:|-----------:|---------------------:|-----------------------------:|-----------------------------------:|-------------------------------------------:|
| 1 | 176.539 | 155.709 | 145.119 | 11.7989 | 17.7976 |
| 4 | 59.9015 | 51.9652 | 50.0884 | 13.2488 | 16.382 |
| 8 | 34.5561 | 31.3024 | 28.2535 | 9.41581 | 18.2387 |

Airline Dataset
---------------

| Number of threads | Baseline | Pointer assignment | Pre-ordered
initialisation | Pointer assignment % improvement | Pre-ordered
initialisation % improvement |

|--------------------:|-----------:|---------------------:|-----------------------------:|-----------------------------------:|-------------------------------------------:|
| 1 | 2127.34 | 1389.7 | 920.373 | 34.6745 | 56.736 |
| 4 | 723.307 | 481.634 | 310.618 | 33.4122 | 57.0558 |
| 8 | 420.722 | 278.397 | 185.265 | 33.8286 | 55.9651 |

mtpl2 Dataset
--------------

| Number of threads | Baseline | Pointer assignment | Pre-ordered
initialisation | Pointer assignment % improvement | Pre-ordered
initialisation % improvement |

|--------------------:|-----------:|---------------------:|-----------------------------:|-----------------------------------:|-------------------------------------------:|
| 1 | 1143.62 | 1020.04 | 998.171 | 10.8055 | 13.0988 |
| 4 | 386.153 | 339.905 | 328.061 | 11.9764 | 14.3729 |
| 8 | 225.995 | 200.665 | 199.057 | 11.2084 | 13.4408 |

These were run using an M2 Pro with 16GB of RAM. All times are in
milliseconds and averages over 10 runs with a batch size of 100,000.

### Motivation and Context
Performance improvements.
  • Loading branch information
adityagoel4512 authored Sep 12, 2023
1 parent 65249f4 commit db558ef
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 165 deletions.
49 changes: 28 additions & 21 deletions onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,31 +64,38 @@ enum MissingTrack : uint8_t {
kFalse = 0
};

template <typename T>
struct TreeNodeElement;

template <typename T>
union PtrOrWeight {
TreeNodeElement<T>* ptr;
struct WeightData {
int32_t weight;
int32_t n_weights;
} weight_data;
};

template <typename T>
struct TreeNodeElement {
int feature_id;

// Stores the node threshold or the weights if the tree has one target.
T value_or_unique_weight;

// onnx specification says hitrates is used to store information about the node,
// The onnx specification says hitrates is used to store information about the node,
// but this information is not used for inference.
// T hitrates;

// True node, false node are obtained by computing `this + truenode_inc_or_first_weight`,
// `this + falsenode_inc_or_n_weights` if the node is not a leaf.
// In case of a leaf, these attributes are used to indicate the position of the weight
// in array `TreeEnsembleCommon::weights_`. If the number of targets or classes is one,
// the weight is also stored in `value_or_unique_weight`.
// This implementation assumes a tree has less than 2^31 nodes,
// and the total number of leave in the set of trees is below 2^31.
// A node cannot point to itself.
int32_t truenode_inc_or_first_weight;
// In case of a leaf, the following attribute indicates the number of weights
// in array `TreeEnsembleCommon::weights_`. If not a leaf, it indicates
// `this + falsenode_inc_or_n_weights` is the false node.
// A node cannot point to itself.
int32_t falsenode_inc_or_n_weights;
// PtrOrWeight acts as a tagged union, with the "tag" being whether the node is a leaf or not (see `is_not_leaf`).

// If it is not a leaf, it is a pointer to the true child node when traversing the decision tree. The false branch is
// always 1 position away from the TreeNodeElement in practice in `TreeEnsembleCommon::nodes_` so it is not stored.

// If it is a leaf, it contains `weight` and `n_weights` attributes which are used to indicate the position of the
// weight in array `TreeEnsembleCommon::weights_`. If the number of targets or classes is one, the weight is also
// stored in `value_or_unique_weight`.
PtrOrWeight<T> truenode_or_weight;
uint8_t flags;

inline NODE_MODE mode() const { return NODE_MODE(flags & 0xF); }
Expand Down Expand Up @@ -189,8 +196,8 @@ class TreeAggregatorSum : public TreeAggregator<InputType, ThresholdType, Output
void ProcessTreeNodePrediction(InlinedVector<ScoreValue<ThresholdType>>& predictions,
const TreeNodeElement<ThresholdType>& root,
gsl::span<const SparseValue<ThresholdType>> weights) const {
auto it = weights.begin() + root.truenode_inc_or_first_weight;
for (int32_t i = 0; i < root.falsenode_inc_or_n_weights; ++i, ++it) {
auto it = weights.begin() + root.truenode_or_weight.weight_data.weight;
for (int32_t i = 0; i < root.truenode_or_weight.weight_data.n_weights; ++i, ++it) {
ORT_ENFORCE(it->i < (int64_t)predictions.size());
predictions[onnxruntime::narrow<size_t>(it->i)].score += it->value;
predictions[onnxruntime::narrow<size_t>(it->i)].has_score = 1;
Expand Down Expand Up @@ -292,8 +299,8 @@ class TreeAggregatorMin : public TreeAggregator<InputType, ThresholdType, Output
void ProcessTreeNodePrediction(InlinedVector<ScoreValue<ThresholdType>>& predictions,
const TreeNodeElement<ThresholdType>& root,
gsl::span<const SparseValue<ThresholdType>> weights) const {
auto it = weights.begin() + root.truenode_inc_or_first_weight;
for (int32_t i = 0; i < root.falsenode_inc_or_n_weights; ++i, ++it) {
auto it = weights.begin() + root.truenode_or_weight.weight_data.weight;
for (int32_t i = 0; i < root.truenode_or_weight.weight_data.n_weights; ++i, ++it) {
predictions[onnxruntime::narrow<size_t>(it->i)].score =
(!predictions[onnxruntime::narrow<size_t>(it->i)].has_score || it->value < predictions[onnxruntime::narrow<size_t>(it->i)].score)
? it->value
Expand Down Expand Up @@ -349,8 +356,8 @@ class TreeAggregatorMax : public TreeAggregator<InputType, ThresholdType, Output
void ProcessTreeNodePrediction(InlinedVector<ScoreValue<ThresholdType>>& predictions,
const TreeNodeElement<ThresholdType>& root,
gsl::span<const SparseValue<ThresholdType>> weights) const {
auto it = weights.begin() + root.truenode_inc_or_first_weight;
for (int32_t i = 0; i < root.falsenode_inc_or_n_weights; ++i, ++it) {
auto it = weights.begin() + root.truenode_or_weight.weight_data.weight;
for (int32_t i = 0; i < root.truenode_or_weight.weight_data.n_weights; ++i, ++it) {
predictions[onnxruntime::narrow<size_t>(it->i)].score =
(!predictions[onnxruntime::narrow<size_t>(it->i)].has_score || it->value > predictions[onnxruntime::narrow<size_t>(it->i)].score)
? it->value
Expand Down
Loading

0 comments on commit db558ef

Please sign in to comment.