diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index b376ae03e9cdd..eeb8ebb3ccefe 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -453,6 +453,7 @@ Do not modify directly.*
|SVMClassifier|*in* X:**T1**
*out* Y:**T2**
*out* Z:**tensor(float)**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int64), tensor(string)|
|SVMRegressor|*in* X:**T**
*out* Y:**tensor(float)**|1+|**T** = tensor(float)|
|Scaler|*in* X:**T**
*out* Y:**tensor(float)**|1+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
+|TreeEnsemble|*in* X:**T**
*out* Y:**T**|5+|**T** = tensor(double), tensor(float)|
|TreeEnsembleClassifier|*in* X:**T1**
*out* Y:**T2**
*out* Z:**tensor(float)**|3+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int64), tensor(string)|
|||[1, 2]|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int64), tensor(string)|
|TreeEnsembleRegressor|*in* X:**T**
*out* Y:**tensor(float)**|3+|**T** = tensor(double), tensor(float)|
diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
index 65eeb4b84e193..0499a15e1df0a 100644
--- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
+++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
@@ -2925,6 +2925,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, int32_t, TreeEnsembleClassifier);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, float, TreeEnsembleRegressor);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, double, TreeEnsembleRegressor);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 5, float, TreeEnsemble);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 5, double, TreeEnsemble);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 4, float_string, LabelEncoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 4, string_float, LabelEncoder);
@@ -3043,6 +3045,10 @@ Status RegisterOnnxMLOperatorKernels(KernelRegistry& kernel_registry) {
TreeEnsembleRegressor)>,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/providers/cpu/ml/ml_common.h b/onnxruntime/core/providers/cpu/ml/ml_common.h
index 2f4ebeabe043e..3359b2a69fe83 100644
--- a/onnxruntime/core/providers/cpu/ml/ml_common.h
+++ b/onnxruntime/core/providers/cpu/ml/ml_common.h
@@ -20,44 +20,48 @@ enum class OUTPUT_MODE {
ALL_SCORES
};
-enum NODE_MODE : uint8_t {
- LEAF = 1,
- BRANCH_LEQ = 2,
- BRANCH_LT = 4,
- BRANCH_GTE = 6,
- BRANCH_GT = 8,
- BRANCH_EQ = 10,
- BRANCH_NEQ = 12
+enum NODE_MODE_ONNX : uint8_t {
+ BRANCH_LEQ = 0,
+ BRANCH_LT = 1,
+ BRANCH_GTE = 2,
+ BRANCH_GT = 3,
+ BRANCH_EQ = 4,
+ BRANCH_NEQ = 5,
+ BRANCH_MEMBER = 6,
+ LEAF = 7,
};
-static inline NODE_MODE MakeTreeNodeMode(const std::string& input) {
+static inline NODE_MODE_ONNX MakeTreeNodeMode(const std::string& input) {
if (input == "BRANCH_LEQ") {
- return NODE_MODE::BRANCH_LEQ;
+ return NODE_MODE_ONNX::BRANCH_LEQ;
}
if (input == "LEAF") {
- return NODE_MODE::LEAF;
+ return NODE_MODE_ONNX::LEAF;
}
if (input == "BRANCH_LT") {
- return NODE_MODE::BRANCH_LT;
+ return NODE_MODE_ONNX::BRANCH_LT;
}
if (input == "BRANCH_GTE") {
- return NODE_MODE::BRANCH_GTE;
+ return NODE_MODE_ONNX::BRANCH_GTE;
}
if (input == "BRANCH_GT") {
- return NODE_MODE::BRANCH_GT;
+ return NODE_MODE_ONNX::BRANCH_GT;
}
if (input == "BRANCH_EQ") {
- return NODE_MODE::BRANCH_EQ;
+ return NODE_MODE_ONNX::BRANCH_EQ;
}
- return NODE_MODE::BRANCH_NEQ;
+ if (input == "BRANCH_MEMBER") {
+ return NODE_MODE_ONNX::BRANCH_MEMBER;
+ }
+ return NODE_MODE_ONNX::BRANCH_NEQ;
}
-enum class POST_EVAL_TRANSFORM {
- NONE,
- LOGISTIC,
- SOFTMAX,
- SOFTMAX_ZERO,
- PROBIT
+enum class POST_EVAL_TRANSFORM : int64_t {
+ NONE = 0,
+ LOGISTIC = 1,
+ SOFTMAX = 2,
+ SOFTMAX_ZERO = 3,
+ PROBIT = 4
};
static inline POST_EVAL_TRANSFORM MakeTransform(const std::string& input) {
@@ -76,11 +80,11 @@ static inline POST_EVAL_TRANSFORM MakeTransform(const std::string& input) {
return POST_EVAL_TRANSFORM::PROBIT;
}
-enum class AGGREGATE_FUNCTION {
- AVERAGE,
- SUM,
- MIN,
- MAX
+enum class AGGREGATE_FUNCTION : int64_t {
+ AVERAGE = 0,
+ SUM = 1,
+ MIN = 2,
+ MAX = 3
};
static inline AGGREGATE_FUNCTION MakeAggregateFunction(const std::string& input) {
diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble.cc b/onnxruntime/core/providers/cpu/ml/tree_ensemble.cc
new file mode 100644
index 0000000000000..3ff501d96b72d
--- /dev/null
+++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble.cc
@@ -0,0 +1,59 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/cpu/ml/tree_ensemble.h"
+#include "core/providers/cpu/ml/tree_ensemble_helper.h"
+#include "core/common/inlined_containers_fwd.h"
+
+namespace onnxruntime {
+namespace ml {
+
+ONNX_CPU_OPERATOR_TYPED_ML_KERNEL(
+ TreeEnsemble,
+ 5,
+ float,
+ KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()).MayInplace(0, 0),
+ TreeEnsemble);
+
+ONNX_CPU_OPERATOR_TYPED_ML_KERNEL(
+ TreeEnsemble,
+ 5,
+ double,
+ KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()).MayInplace(0, 0),
+ TreeEnsemble);
+
+template
+TreeEnsemble::TreeEnsemble(const OpKernelInfo& info) : OpKernel(info) {
+ if constexpr (std::is_same::value) {
+ p_tree_ensemble_ = std::make_unique>();
+ } else {
+ p_tree_ensemble_ = std::make_unique>();
+ }
+ ORT_THROW_IF_ERROR(p_tree_ensemble_->Init(info));
+}
+
+template
+Status TreeEnsemble::GetRemovableAttributes(InlinedVector& removable_attributes) const {
+ InlinedVector names{
+ "leaf_targetids", "leaf_weights", "membership_values", "nodes_falseleafs",
+ "nodes_falsenodeids", "nodes_featureids", "nodes_hitrates", "nodes_missing_value_tracks_true",
+ "nodes_modes", "nodes_splits", "nodes_trueleafs", "nodes_truenodeids"};
+ removable_attributes.swap(names);
+ return Status::OK();
+}
+
+template
+common::Status TreeEnsemble::Compute(OpKernelContext* context) const {
+ const auto* X = context->Input(0);
+ if (X == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
+ if (X->Shape().NumDimensions() == 0) {
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
+ "Input shape needs to be at least a single dimension.");
+ }
+ int64_t N = X->Shape().NumDimensions() == 1 ? 1 : X->Shape()[0];
+ Tensor* Y = context->Output(0, {N, p_tree_ensemble_->get_target_or_class_count()});
+ return p_tree_ensemble_->compute(context, X, Y, NULL);
+}
+
+} // namespace ml
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble.h
new file mode 100644
index 0000000000000..697aae045a7e3
--- /dev/null
+++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble.h
@@ -0,0 +1,25 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "tree_ensemble_common.h"
+
+namespace onnxruntime {
+namespace ml {
+template
+class TreeEnsemble final : public OpKernel {
+ typedef T InputType; // input type
+ typedef float OutputType; // output type
+ public:
+ explicit TreeEnsemble(const OpKernelInfo& info);
+ common::Status Compute(OpKernelContext* context) const override;
+ Status GetRemovableAttributes(InlinedVector& removable_attributes) const override;
+
+ private:
+ // Pointer on one instance of
+ // detail::TreeEnsembleCommonV5
+ // where ThresholdType is defined after accessing the attributes.
+ std::unique_ptr p_tree_ensemble_;
+};
+} // namespace ml
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h
index b031a6f0cefa3..bf3fd37d10f5c 100644
--- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h
+++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h
@@ -78,6 +78,40 @@ union PtrOrWeight {
} weight_data;
};
+enum NODE_MODE_ORT : uint8_t {
+ LEAF = 1,
+ BRANCH_LEQ = 2,
+ BRANCH_LT = 4,
+ BRANCH_GTE = 6,
+ BRANCH_GT = 8,
+ BRANCH_EQ = 10,
+ BRANCH_NEQ = 12,
+ BRANCH_MEMBER = 14,
+};
+
+inline NODE_MODE_ORT Convert_NODE_MODE_ONNX_to_ORT(NODE_MODE_ONNX node_mode) {
+ switch (node_mode) {
+ case NODE_MODE_ONNX::LEAF:
+ return NODE_MODE_ORT::LEAF;
+ case NODE_MODE_ONNX::BRANCH_LEQ:
+ return NODE_MODE_ORT::BRANCH_LEQ;
+ case NODE_MODE_ONNX::BRANCH_LT:
+ return NODE_MODE_ORT::BRANCH_LT;
+ case NODE_MODE_ONNX::BRANCH_GTE:
+ return NODE_MODE_ORT::BRANCH_GTE;
+ case NODE_MODE_ONNX::BRANCH_GT:
+ return NODE_MODE_ORT::BRANCH_GT;
+ case NODE_MODE_ONNX::BRANCH_EQ:
+ return NODE_MODE_ORT::BRANCH_EQ;
+ case NODE_MODE_ONNX::BRANCH_NEQ:
+ return NODE_MODE_ORT::BRANCH_NEQ;
+ case NODE_MODE_ONNX::BRANCH_MEMBER:
+ return NODE_MODE_ORT::BRANCH_MEMBER;
+ default:
+ ORT_THROW("Unexpected value for node_mode");
+ };
+}
+
template
struct TreeNodeElement {
int feature_id;
@@ -98,10 +132,10 @@ struct TreeNodeElement {
// weight in array `TreeEnsembleCommon::weights_`. If the number of targets or classes is one, the weight is also
// stored in `value_or_unique_weight`.
PtrOrWeight truenode_or_weight;
- uint8_t flags;
+ NODE_MODE_ORT flags;
- inline NODE_MODE mode() const { return NODE_MODE(flags & 0xF); }
- inline bool is_not_leaf() const { return !(flags & NODE_MODE::LEAF); }
+ inline NODE_MODE_ORT mode() const { return NODE_MODE_ORT(flags & 0xF); }
+ inline bool is_not_leaf() const { return !(flags & NODE_MODE_ORT::LEAF); }
inline bool is_missing_track_true() const { return flags & MissingTrack::kTrue; }
};
diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h
new file mode 100644
index 0000000000000..d2d1ba9863ac7
--- /dev/null
+++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h
@@ -0,0 +1,321 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/common/inlined_containers.h"
+#include "core/common/common.h"
+#include "core/framework/op_kernel.h"
+#include "ml_common.h"
+#include "tree_ensemble_helper.h"
+#include
+
+namespace onnxruntime {
+namespace ml {
+namespace detail {
+
+inline bool _isnan_(float x) { return std::isnan(x); }
+inline bool _isnan_(double x) { return std::isnan(x); }
+inline bool _isnan_(int64_t) { return false; }
+inline bool _isnan_(int32_t) { return false; }
+
+template
+struct TreeEnsembleAttributesV3 {
+ TreeEnsembleAttributesV3() {}
+ TreeEnsembleAttributesV3(const OpKernelInfo& info, bool classifier) {
+#if !defined(ORT_MINIMAL_BUILD)
+ ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "base_values_as_tensor", base_values_as_tensor));
+ ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_hitrates_as_tensor", nodes_hitrates_as_tensor));
+ ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_values_as_tensor", nodes_values_as_tensor));
+ if (classifier) {
+ ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "class_weights_as_tensor", target_class_weights_as_tensor));
+ } else {
+ ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "target_weights_as_tensor", target_class_weights_as_tensor));
+ }
+#endif
+
+ aggregate_function = info.GetAttrOrDefault("aggregate_function", "SUM");
+ base_values = info.GetAttrsOrDefault("base_values");
+ nodes_falsenodeids = info.GetAttrsOrDefault("nodes_falsenodeids");
+ nodes_featureids = info.GetAttrsOrDefault("nodes_featureids");
+ nodes_missing_value_tracks_true = info.GetAttrsOrDefault("nodes_missing_value_tracks_true");
+
+ std::vector nodes_modes_string = info.GetAttrsOrDefault("nodes_modes");
+ nodes_modes.reserve(nodes_modes_string.size());
+ for (auto s : nodes_modes_string) {
+ nodes_modes.emplace_back(MakeTreeNodeMode(s));
+ }
+
+ nodes_nodeids = info.GetAttrsOrDefault("nodes_nodeids");
+ nodes_treeids = info.GetAttrsOrDefault("nodes_treeids");
+ nodes_truenodeids = info.GetAttrsOrDefault("nodes_truenodeids");
+ nodes_values = info.GetAttrsOrDefault("nodes_values");
+ post_transform = info.GetAttrOrDefault("post_transform", "NONE");
+
+ if (classifier) {
+ target_class_ids = info.GetAttrsOrDefault("class_ids");
+ target_class_nodeids = info.GetAttrsOrDefault("class_nodeids");
+ target_class_treeids = info.GetAttrsOrDefault("class_treeids");
+ target_class_weights = info.GetAttrsOrDefault("class_weights");
+ classlabels_strings = info.GetAttrsOrDefault("classlabels_strings");
+ classlabels_int64s = info.GetAttrsOrDefault("classlabels_int64s");
+ n_targets_or_classes = classlabels_strings.empty() ? classlabels_int64s.size()
+ : classlabels_strings.size();
+ } else {
+ n_targets_or_classes = info.GetAttrOrDefault("n_targets", 0);
+ target_class_ids = info.GetAttrsOrDefault("target_ids");
+ target_class_nodeids = info.GetAttrsOrDefault("target_nodeids");
+ target_class_treeids = info.GetAttrsOrDefault("target_treeids");
+ target_class_weights = info.GetAttrsOrDefault("target_weights");
+
+ ORT_ENFORCE(n_targets_or_classes > 0);
+ ORT_ENFORCE(nodes_falsenodeids.size() == nodes_featureids.size());
+ ORT_ENFORCE(nodes_falsenodeids.size() == nodes_modes_string.size());
+ ORT_ENFORCE(nodes_falsenodeids.size() == nodes_nodeids.size());
+ ORT_ENFORCE(nodes_falsenodeids.size() == nodes_treeids.size());
+ ORT_ENFORCE(nodes_falsenodeids.size() == nodes_truenodeids.size());
+ ORT_ENFORCE(nodes_falsenodeids.size() == nodes_values.size() ||
+ nodes_falsenodeids.size() == nodes_values_as_tensor.size());
+ ORT_ENFORCE(target_class_ids.size() == target_class_nodeids.size());
+ ORT_ENFORCE(target_class_ids.size() == target_class_treeids.size());
+ ORT_ENFORCE(target_class_weights.empty() || target_class_ids.size() == target_class_weights.size());
+ ORT_ENFORCE(base_values.empty() || base_values_as_tensor.empty());
+ ORT_ENFORCE(nodes_hitrates.empty() || nodes_hitrates_as_tensor.empty());
+ ORT_ENFORCE(nodes_values.empty() || nodes_values_as_tensor.empty());
+ ORT_ENFORCE(target_class_weights.empty() || target_class_weights_as_tensor.empty());
+ ORT_ENFORCE(nodes_modes_string.size() < std::numeric_limits::max());
+ }
+ }
+
+ std::string aggregate_function;
+ std::vector base_values;
+ std::vector base_values_as_tensor;
+ int64_t n_targets_or_classes;
+ std::vector nodes_falsenodeids;
+ std::vector nodes_featureids;
+ std::vector nodes_hitrates;
+ std::vector nodes_hitrates_as_tensor;
+ std::vector nodes_missing_value_tracks_true;
+ std::vector nodes_modes;
+ std::vector nodes_nodeids;
+ std::vector nodes_treeids;
+ std::vector nodes_truenodeids;
+ std::vector nodes_values;
+ std::vector nodes_values_as_tensor;
+ std::string post_transform;
+ std::vector target_class_ids;
+ std::vector target_class_nodeids;
+ std::vector target_class_treeids;
+ std::vector target_class_weights;
+ std::vector target_class_weights_as_tensor;
+ std::vector classlabels_strings;
+ std::vector classlabels_int64s;
+ std::vector class_labels;
+};
+
+template
+struct TreeEnsembleAttributesV5 {
+ TreeEnsembleAttributesV5() {}
+ TreeEnsembleAttributesV5(const OpKernelInfo& info) {
+#if !defined(ORT_MINIMAL_BUILD)
+ std::vector nodes_modes_i;
+ ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "leaf_weights", leaf_weights));
+ ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "membership_values", membership_values));
+ ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_hitrates", nodes_hitrates));
+ ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_modes", nodes_modes_i));
+ ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_splits", nodes_splits));
+ nodes_modes.reserve(nodes_modes.size());
+ for (auto i : nodes_modes_i) {
+ nodes_modes.push_back(static_cast(i));
+ }
+#else
+ // GetVectorAttrsOrDefault is not part of the minimal build.
+ // As a result, TreeEnsemble v5 cannot be available in this build.
+ ORT_THROW("TreeEnsemble(ai.onnx.ml==5) is not supported with the minimal build.");
+#endif
+
+ aggregate_function = info.GetAttrOrDefault("aggregate_function", 1);
+ leaf_targetids = info.GetAttrsOrDefault("leaf_targetids");
+ n_targets = info.GetAttrOrDefault("n_targets", 0);
+ nodes_falseleafs = info.GetAttrsOrDefault("nodes_falseleafs");
+ nodes_falsenodeids = info.GetAttrsOrDefault("nodes_falsenodeids");
+ nodes_featureids = info.GetAttrsOrDefault("nodes_featureids");
+ nodes_missing_value_tracks_true = info.GetAttrsOrDefault("nodes_missing_value_tracks_true");
+ nodes_trueleafs = info.GetAttrsOrDefault("nodes_trueleafs");
+ nodes_truenodeids = info.GetAttrsOrDefault("nodes_truenodeids");
+ post_transform = info.GetAttrOrDefault("post_transform", 0);
+ tree_roots = info.GetAttrsOrDefault("tree_roots");
+ }
+
+ void convert_to_v3(TreeEnsembleAttributesV3& output) const {
+ // Doing all transformations to get the old format.
+ output.n_targets_or_classes = n_targets;
+ output.aggregate_function = aggregateFunctionToString();
+ output.post_transform = postTransformToString();
+ std::vector> membership_values_by_id;
+ getMembershipValuesById(membership_values_by_id);
+ transformInputAllTrees(output, membership_values_by_id);
+ }
+
+ int64_t aggregate_function;
+ std::vector leaf_targetids;
+ std::vector leaf_weights;
+ std::vector membership_values;
+ int64_t n_targets;
+ std::vector nodes_falseleafs;
+ std::vector nodes_falsenodeids;
+ std::vector nodes_featureids;
+ std::vector nodes_hitrates;
+ std::vector nodes_missing_value_tracks_true;
+ std::vector nodes_modes;
+ std::vector nodes_splits;
+ std::vector nodes_trueleafs;
+ std::vector nodes_truenodeids;
+ int64_t post_transform;
+ std::vector tree_roots;
+
+ private:
+ // `membership_values` are seperated by NAN for different nodes
+ // It is more convenient to preserve the values for each node in a vector
+ // The vector would be empty for nodes that are not `BRANCH_MEMBER`
+ void getMembershipValuesById(std::vector>& membership_values_by_id) const {
+ membership_values_by_id.clear();
+ membership_values_by_id.reserve(nodes_modes.size());
+
+ size_t curr_id = 0;
+ for (const auto node_mode : nodes_modes) {
+ membership_values_by_id.emplace_back();
+ if (node_mode != NODE_MODE_ONNX::BRANCH_MEMBER) {
+ continue;
+ }
+
+ while (curr_id < membership_values.size() && !_isnan_(membership_values[curr_id])) {
+ membership_values_by_id.back().push_back(membership_values[curr_id++]);
+ }
+ curr_id++;
+ }
+ }
+
+ std::string aggregateFunctionToString() const {
+ switch (aggregate_function) {
+ case static_cast(AGGREGATE_FUNCTION::AVERAGE):
+ return "AVERAGE";
+ case static_cast(AGGREGATE_FUNCTION::SUM):
+ return "SUM";
+ case static_cast(AGGREGATE_FUNCTION::MIN):
+ return "MIN";
+ case static_cast(AGGREGATE_FUNCTION::MAX):
+ return "MAX";
+ default:
+ ORT_THROW("Unknown value for aggregate_function.");
+ }
+ }
+
+ std::string postTransformToString() const {
+ switch (post_transform) {
+ case static_cast(POST_EVAL_TRANSFORM::NONE):
+ return "NONE";
+ case static_cast(POST_EVAL_TRANSFORM::SOFTMAX):
+ return "SOFTMAX";
+ case static_cast(POST_EVAL_TRANSFORM::LOGISTIC):
+ return "LOGISTIC";
+ case static_cast(POST_EVAL_TRANSFORM::SOFTMAX_ZERO):
+ return "SOFTMAX_ZERO";
+ case static_cast(POST_EVAL_TRANSFORM::PROBIT):
+ return "PROBIT";
+ default:
+ ORT_THROW("Unknown value for post_transform.");
+ }
+ }
+
+ int64_t transformInputOneTree(
+ const size_t curr_id, const int64_t curr_treeid, const int64_t curr_nodeid, const size_t curr_membership_value_id,
+ const bool is_leaf, std::vector>& membership_values_by_id,
+ TreeEnsembleAttributesV3& output) const {
+ output.nodes_nodeids.push_back(curr_nodeid);
+ output.nodes_treeids.push_back(curr_treeid);
+
+ if (is_leaf) {
+ output.nodes_modes.push_back(NODE_MODE_ONNX::LEAF);
+ output.target_class_ids.push_back(leaf_targetids[curr_id]);
+ output.target_class_nodeids.push_back(curr_nodeid);
+ output.target_class_treeids.push_back(curr_treeid);
+ output.target_class_weights_as_tensor.push_back(leaf_weights[curr_id]);
+
+ // the below are irrelevant for a `LEAF`
+ output.nodes_featureids.push_back(0);
+ output.nodes_truenodeids.push_back(0);
+ output.nodes_falsenodeids.push_back(0);
+ output.nodes_values_as_tensor.push_back(0);
+ if (!nodes_hitrates.empty()) {
+ output.nodes_hitrates.push_back(0);
+ }
+ if (!nodes_missing_value_tracks_true.empty()) {
+ output.nodes_missing_value_tracks_true.push_back(0);
+ }
+
+ return curr_nodeid;
+ }
+
+ output.nodes_featureids.push_back(nodes_featureids[curr_id]);
+ if (!nodes_hitrates.empty()) {
+ output.nodes_hitrates_as_tensor.push_back(nodes_hitrates[curr_id]);
+ }
+ if (!nodes_missing_value_tracks_true.empty()) {
+ output.nodes_missing_value_tracks_true.push_back(nodes_missing_value_tracks_true[curr_id]);
+ }
+
+ // unroll `BRANCH_MEMBER` to a chain of `BRANCH_EQ`
+ if (nodes_modes[curr_id] == NODE_MODE_ONNX::BRANCH_MEMBER) {
+ output.nodes_modes.push_back(NODE_MODE_ONNX::BRANCH_EQ);
+ output.nodes_values_as_tensor.push_back(membership_values_by_id[curr_id][curr_membership_value_id]);
+ } else {
+ output.nodes_modes.push_back(nodes_modes[curr_id]);
+ output.nodes_values_as_tensor.push_back(nodes_splits[curr_id]);
+ }
+
+ size_t falsenodeid_id = output.nodes_falsenodeids.size();
+ output.nodes_falsenodeids.push_back(0); // change after pushing truenode subtree
+
+ int64_t true_nodeid = curr_nodeid + 1;
+ output.nodes_truenodeids.push_back(true_nodeid);
+ true_nodeid = transformInputOneTree(onnxruntime::narrow(nodes_truenodeids[curr_id]),
+ curr_treeid, true_nodeid, 0U, nodes_trueleafs[curr_id] != 0,
+ membership_values_by_id, output);
+
+ int64_t false_nodeid = true_nodeid + 1;
+ output.nodes_falsenodeids[falsenodeid_id] = false_nodeid;
+
+ // if node is `BRANCH_MEMBER` we are unrolling the `membership_values` for that node
+ // therefore if the value is not the last, the `falsenode_id` must be pointing to the "same" node with a different membership value
+ // so in that case we are only moving the pointer for `membership_values`
+ //
+ // otherwise, the `falsenode_id` is pointing to the real falsenode subtree
+ if (nodes_modes[curr_id] == NODE_MODE_ONNX::BRANCH_MEMBER &&
+ curr_membership_value_id + 1 < membership_values_by_id[curr_id].size()) {
+ false_nodeid = transformInputOneTree(curr_id, curr_treeid, false_nodeid, curr_membership_value_id + 1, false,
+ membership_values_by_id, output);
+ } else {
+ false_nodeid = transformInputOneTree(onnxruntime::narrow(nodes_falsenodeids[curr_id]),
+ curr_treeid, false_nodeid, 0U, nodes_falseleafs[curr_id] != 0,
+ membership_values_by_id, output);
+ }
+ return false_nodeid;
+ }
+
+ void transformInputAllTrees(TreeEnsembleAttributesV3& output,
+ std::vector>& membership_values_by_id) const {
+ int64_t curr_treeid = 0;
+ for (const int64_t& tree_root : tree_roots) {
+ size_t tree_root_size_t = onnxruntime::narrow(tree_root);
+ transformInputOneTree(tree_root_size_t, curr_treeid, 0, 0U,
+ nodes_falsenodeids[tree_root_size_t] == nodes_truenodeids[tree_root_size_t],
+ membership_values_by_id, output);
+ curr_treeid++;
+ }
+ }
+};
+
+} // namespace detail
+} // namespace ml
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h
index 94f79518ae8da..10d4db0e0e3b0 100644
--- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h
+++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h
@@ -3,15 +3,21 @@
#pragma once
-#include "tree_ensemble_aggregator.h"
#include
#include "core/platform/threadpool.h"
#include "tree_ensemble_helper.h"
+#include "tree_ensemble_attribute.h"
+#include "tree_ensemble_aggregator.h"
namespace onnxruntime {
namespace ml {
namespace detail {
+/**
+ * These attributes are the kernel attributes. They are different from the onnx operator attributes
+ * to improve the computation efficiency. The initialization consists in moving the onnx attributes
+ * into the kernel attributes.
+ */
class TreeEnsembleCommonAttributes {
public:
int64_t get_target_or_class_count() const { return this->n_targets_or_classes_; }
@@ -57,27 +63,7 @@ class TreeEnsembleCommon : public TreeEnsembleCommonAttributes {
Status Init(int parallel_tree,
int parallel_tree_N,
int parallel_N,
- const std::string& aggregate_function,
- const std::vector& base_values,
- const std::vector& base_values_as_tensor,
- int64_t n_targets_or_classes,
- const std::vector& nodes_falsenodeids,
- const std::vector& nodes_featureids,
- const std::vector& nodes_hitrates,
- const std::vector& nodes_hitrates_as_tensor,
- const std::vector& nodes_missing_value_tracks_true,
- const std::vector& nodes_modes,
- const std::vector& nodes_nodeids,
- const std::vector& nodes_treeids,
- const std::vector& nodes_truenodeids,
- const std::vector& nodes_values,
- const std::vector& nodes_values_as_tensor,
- const std::string& post_transform,
- const std::vector& target_class_ids,
- const std::vector& target_class_nodeids,
- const std::vector& target_class_treeids,
- const std::vector& target_class_weights,
- const std::vector& target_class_weights_as_tensor);
+ const TreeEnsembleAttributesV3& attributes);
protected:
TreeNodeElement* ProcessTreeNodeLeave(TreeNodeElement* root,
@@ -87,49 +73,52 @@ class TreeEnsembleCommon : public TreeEnsembleCommonAttributes {
void ComputeAgg(concurrency::ThreadPool* ttp, const Tensor* X, Tensor* Y, Tensor* label, const AGG& agg) const;
private:
- size_t AddNodes(const size_t i, const InlinedVector& cmodes, const InlinedVector& truenode_ids,
- const InlinedVector& falsenode_ids, const std::vector& nodes_featureids,
- const std::vector& nodes_values_as_tensor, const std::vector& node_values,
- const std::vector& nodes_missing_value_tracks_true, std::vector& updated_mapping,
- int64_t tree_id, const InlinedVector& node_tree_ids);
+ bool CheckIfSubtreesAreEqual(const size_t left_id, const size_t right_id, const int64_t tree_id, const InlinedVector& cmodes,
+ const InlinedVector& truenode_ids, const InlinedVector& falsenode_ids, gsl::span nodes_featureids,
+ gsl::span nodes_values_as_tensor, gsl::span node_values,
+ gsl::span target_class_weights, gsl::span target_class_weights_as_tensor,
+ const InlinedVector& node_tree_ids, InlinedVector> indices);
+ size_t AddNodes(const size_t i, const InlinedVector& cmodes, const InlinedVector& truenode_ids,
+ const InlinedVector& falsenode_ids, gsl::span nodes_featureids,
+ gsl::span nodes_values_as_tensor, gsl::span node_values,
+ gsl::span nodes_missing_value_tracks_true, std::vector& updated_mapping,
+ int64_t tree_id, const InlinedVector& node_tree_ids, gsl::span target_class_weights,
+ gsl::span target_class_weights_as_tensor, InlinedVector>& indices);
};
+// Below is simple implementation of `bit_cast` as it is supported from c++20 and the current supported version is c++17
+// Remove it when that is not the case
+template
+std::enable_if_t<
+ sizeof(To) == sizeof(From) &&
+ std::is_trivially_copyable_v &&
+ std::is_trivially_copyable_v,
+ To>
+ // constexpr support needs compiler magic
+ static bit_cast(const From& src) noexcept {
+ static_assert(std::is_trivially_constructible_v,
+ "This implementation additionally requires "
+ "destination type to be trivially constructible");
+
+ To dst;
+ std::memcpy(&dst, &src, sizeof(To));
+ return dst;
+}
+
+template
+std::conditional_t bit_cast_int(T val) {
+ if constexpr (sizeof(T) == sizeof(uint32_t)) {
+ return bit_cast(val);
+ } else if constexpr (sizeof(T) == sizeof(uint64_t)) {
+ return bit_cast(val);
+ }
+ static_assert(sizeof(T) == sizeof(uint32_t) || sizeof(T) == sizeof(uint64_t));
+}
+
template
Status TreeEnsembleCommon::Init(const OpKernelInfo& info) {
- std::vector base_values_as_tensor, nodes_hitrates_as_tensor,
- nodes_values_as_tensor, target_weights_as_tensor;
-#if !defined(ORT_MINIMAL_BUILD)
- ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "base_values_as_tensor", base_values_as_tensor));
- ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_hitrates_as_tensor", nodes_hitrates_as_tensor));
- ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_values_as_tensor", nodes_values_as_tensor));
- ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "target_weights_as_tensor", target_weights_as_tensor));
-#endif
-
- return Init(
- 80,
- 128,
- 50,
- info.GetAttrOrDefault("aggregate_function", "SUM"),
- info.GetAttrsOrDefault("base_values"),
- base_values_as_tensor,
- info.GetAttrOrDefault("n_targets", 0),
- info.GetAttrsOrDefault("nodes_falsenodeids"),
- info.GetAttrsOrDefault("nodes_featureids"),
- info.GetAttrsOrDefault("nodes_hitrates"),
- nodes_hitrates_as_tensor,
- info.GetAttrsOrDefault("nodes_missing_value_tracks_true"),
- info.GetAttrsOrDefault("nodes_modes"),
- info.GetAttrsOrDefault("nodes_nodeids"),
- info.GetAttrsOrDefault("nodes_treeids"),
- info.GetAttrsOrDefault("nodes_truenodeids"),
- info.GetAttrsOrDefault("nodes_values"),
- nodes_values_as_tensor,
- info.GetAttrOrDefault("post_transform", "NONE"),
- info.GetAttrsOrDefault("target_ids"),
- info.GetAttrsOrDefault("target_nodeids"),
- info.GetAttrsOrDefault("target_treeids"),
- info.GetAttrsOrDefault("target_weights"),
- target_weights_as_tensor);
+ TreeEnsembleAttributesV3 attributes(info, false);
+ return Init(80, 128, 50, attributes);
}
template
@@ -137,72 +126,35 @@ Status TreeEnsembleCommon::Init(
int parallel_tree,
int parallel_tree_N,
int parallel_N,
- const std::string& aggregate_function,
- const std::vector& base_values,
- const std::vector& base_values_as_tensor,
- int64_t n_targets_or_classes,
- const std::vector& nodes_falsenodeids,
- const std::vector& nodes_featureids,
- const std::vector& nodes_hitrates,
- const std::vector& nodes_hitrates_as_tensor,
- const std::vector& nodes_missing_value_tracks_true,
- const std::vector& nodes_modes,
- const std::vector& nodes_nodeids,
- const std::vector& nodes_treeids,
- const std::vector& nodes_truenodeids,
- const std::vector& nodes_values,
- const std::vector& nodes_values_as_tensor,
- const std::string& post_transform,
- const std::vector& target_class_ids,
- const std::vector& target_class_nodeids,
- const std::vector& target_class_treeids,
- const std::vector& target_class_weights,
- const std::vector& target_class_weights_as_tensor) {
+ const TreeEnsembleAttributesV3& attributes) {
parallel_tree_ = parallel_tree;
parallel_tree_N_ = parallel_tree_N;
parallel_N_ = parallel_N;
- ORT_ENFORCE(n_targets_or_classes > 0);
- ORT_ENFORCE(nodes_falsenodeids.size() == nodes_featureids.size());
- ORT_ENFORCE(nodes_falsenodeids.size() == nodes_modes.size());
- ORT_ENFORCE(nodes_falsenodeids.size() == nodes_nodeids.size());
- ORT_ENFORCE(nodes_falsenodeids.size() == nodes_treeids.size());
- ORT_ENFORCE(nodes_falsenodeids.size() == nodes_truenodeids.size());
- ORT_ENFORCE(nodes_falsenodeids.size() == nodes_values.size() ||
- nodes_falsenodeids.size() == nodes_values_as_tensor.size());
- ORT_ENFORCE(target_class_ids.size() == target_class_nodeids.size());
- ORT_ENFORCE(target_class_ids.size() == target_class_treeids.size());
- ORT_ENFORCE(target_class_weights.empty() || target_class_ids.size() == target_class_weights.size());
- ORT_ENFORCE(base_values.empty() || base_values_as_tensor.empty());
- ORT_ENFORCE(nodes_hitrates.empty() || nodes_hitrates_as_tensor.empty());
- ORT_ENFORCE(nodes_values.empty() || nodes_values_as_tensor.empty());
- ORT_ENFORCE(target_class_weights.empty() || target_class_weights_as_tensor.empty());
-
- aggregate_function_ = MakeAggregateFunction(aggregate_function);
- post_transform_ = MakeTransform(post_transform);
- if (!base_values_as_tensor.empty()) {
- ORT_ENFORCE(base_values.empty());
- base_values_ = base_values_as_tensor;
+ aggregate_function_ = MakeAggregateFunction(attributes.aggregate_function);
+ post_transform_ = MakeTransform(attributes.post_transform);
+ if (!attributes.base_values_as_tensor.empty()) {
+ ORT_ENFORCE(attributes.base_values.empty());
+ base_values_ = attributes.base_values_as_tensor;
} else {
- base_values_.reserve(base_values.size());
- for (size_t i = 0, limit = base_values.size(); i < limit; ++i) {
- base_values_.push_back(static_cast(base_values[i]));
+ base_values_.reserve(attributes.base_values.size());
+ for (size_t i = 0, limit = attributes.base_values.size(); i < limit; ++i) {
+ base_values_.push_back(static_cast(attributes.base_values[i]));
}
}
- n_targets_or_classes_ = n_targets_or_classes;
+ n_targets_or_classes_ = attributes.n_targets_or_classes;
max_tree_depth_ = 1000;
- ORT_ENFORCE(nodes_modes.size() < std::numeric_limits::max());
// Additional members
size_t limit;
uint32_t i;
- InlinedVector cmodes;
- cmodes.reserve(nodes_modes.size());
+ InlinedVector cmodes;
+ cmodes.reserve(attributes.nodes_modes.size());
same_mode_ = true;
int fpos = -1;
- for (i = 0, limit = nodes_modes.size(); i < limit; ++i) {
- cmodes.push_back(MakeTreeNodeMode(nodes_modes[i]));
- if (cmodes[i] == NODE_MODE::LEAF) continue;
+ for (i = 0, limit = attributes.nodes_modes.size(); i < limit; ++i) {
+ cmodes.push_back(attributes.nodes_modes[i]);
+ if (cmodes[i] == NODE_MODE_ONNX::LEAF) continue;
if (fpos == -1) {
fpos = static_cast(i);
continue;
@@ -210,7 +162,7 @@ Status TreeEnsembleCommon::Init(
if (cmodes[i] != cmodes[fpos]) same_mode_ = false;
}
- n_nodes_ = nodes_treeids.size();
+ n_nodes_ = attributes.nodes_treeids.size();
limit = static_cast(n_nodes_);
InlinedVector node_tree_ids;
node_tree_ids.reserve(limit);
@@ -227,7 +179,7 @@ Status TreeEnsembleCommon::Init(
// Build node_tree_ids and node_tree_ids_map and truenode_ids and falsenode_ids
for (i = 0; i < limit; ++i) {
- TreeNodeElementId node_tree_id{static_cast(nodes_treeids[i]), static_cast(nodes_nodeids[i])};
+ TreeNodeElementId node_tree_id{static_cast(attributes.nodes_treeids[i]), static_cast(attributes.nodes_nodeids[i])};
auto p = node_tree_ids_map.insert(std::pair(node_tree_id, i));
if (!p.second) {
ORT_THROW("Node ", node_tree_id.node_id, " in tree ", node_tree_id.tree_id, " is already there.");
@@ -237,13 +189,13 @@ Status TreeEnsembleCommon::Init(
TreeNodeElementId coor;
for (i = 0; i < limit; ++i) {
- if (cmodes[i] == NODE_MODE::LEAF) {
+ if (cmodes[i] == NODE_MODE_ONNX::LEAF) {
truenode_ids.push_back(0);
falsenode_ids.push_back(0);
} else {
TreeNodeElementId& node_tree_id = node_tree_ids[i];
coor.tree_id = node_tree_id.tree_id;
- coor.node_id = static_cast(nodes_truenodeids[i]);
+ coor.node_id = static_cast(attributes.nodes_truenodeids[i]);
ORT_ENFORCE((coor.node_id >= 0 && coor.node_id < n_nodes_));
auto found = node_tree_ids_map.find(coor);
@@ -255,7 +207,7 @@ Status TreeEnsembleCommon::Init(
}
truenode_ids.emplace_back(found->second);
- coor.node_id = static_cast(nodes_falsenodeids[i]);
+ coor.node_id = static_cast(attributes.nodes_falsenodeids[i]);
ORT_ENFORCE((coor.node_id >= 0 && coor.node_id < n_nodes_));
found = node_tree_ids_map.find(coor);
if (found == node_tree_ids_map.end()) {
@@ -270,41 +222,38 @@ Status TreeEnsembleCommon::Init(
}
}
+ // Sort targets
+ InlinedVector> indices;
+ indices.reserve(attributes.target_class_nodeids.size());
+ for (i = 0, limit = attributes.target_class_nodeids.size(); i < limit; i++) {
+ indices.emplace_back(
+ TreeNodeElementId{attributes.target_class_treeids[i], attributes.target_class_nodeids[i]}, i);
+ }
+
+ std::sort(indices.begin(), indices.end());
+
// Let's construct nodes_ such that the false branch is always the next element in nodes_.
// updated_mapping will translates the old position of each node to the new node position in nodes_.
- std::vector updated_mapping(nodes_treeids.size(), 0);
+ std::vector updated_mapping(attributes.nodes_treeids.size(), 0);
int64_t previous_tree_id = -1;
for (i = 0; i < n_nodes_; ++i) {
if (previous_tree_id == -1 || (previous_tree_id != node_tree_ids[i].tree_id)) {
// New tree.
int64_t tree_id = node_tree_ids[i].tree_id;
size_t root_position =
- AddNodes(i, cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, nodes_values,
- nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids);
+ AddNodes(i, cmodes, truenode_ids, falsenode_ids, attributes.nodes_featureids, attributes.nodes_values_as_tensor, attributes.nodes_values,
+ attributes.nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids,
+ attributes.target_class_weights, attributes.target_class_weights_as_tensor, indices);
roots_.push_back(&nodes_[root_position]);
previous_tree_id = tree_id;
}
}
-
n_trees_ = roots_.size();
- if (((int64_t)nodes_.size()) != n_nodes_) {
- ORT_THROW("Number of nodes in nodes_ (", nodes_.size(), ") is different from n_nodes (", n_nodes_, ").");
- }
-
- // Sort targets
- InlinedVector> indices;
- indices.reserve(target_class_nodeids.size());
- for (i = 0, limit = target_class_nodeids.size(); i < limit; i++) {
- indices.emplace_back(
- std::pair(TreeNodeElementId{target_class_treeids[i], target_class_nodeids[i]}, i));
- }
-
- std::sort(indices.begin(), indices.end());
TreeNodeElementId ind;
SparseValue w;
size_t indi;
- for (indi = 0, limit = target_class_nodeids.size(); indi < limit; ++indi) {
+ for (indi = 0, limit = attributes.target_class_nodeids.size(); indi < limit; ++indi) {
ind = indices[indi].first;
i = indices[indi].second;
auto found = node_tree_ids_map.find(ind);
@@ -319,9 +268,10 @@ Status TreeEnsembleCommon::Init(
// ORT_THROW("Node ", ind.tree_id, "-", ind.node_id, " is not a leaf.");
continue;
}
- w.i = target_class_ids[i];
- w.value = target_class_weights_as_tensor.empty() ? static_cast(target_class_weights[i])
- : target_class_weights_as_tensor[i];
+ w.i = attributes.target_class_ids[i];
+ w.value = attributes.target_class_weights_as_tensor.empty()
+ ? static_cast(attributes.target_class_weights[i])
+ : attributes.target_class_weights_as_tensor[i];
if (leaf.truenode_or_weight.weight_data.n_weights == 0) {
leaf.truenode_or_weight.weight_data.weight = static_cast(weights_.size());
leaf.value_or_unique_weight = w.value;
@@ -331,7 +281,7 @@ Status TreeEnsembleCommon