From a2ba3cb54789458fc53e05ec102965d14e0866c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 22 Nov 2024 19:48:23 +0100 Subject: [PATCH] Implementation of TreeEnsemble ai.onnx.ml==5 (#22333) ### Description Merges PR #21851, #21222. Implements TreeEnsemble from ai.onnx.ml==5 (CPU). --------- Co-authored-by: Bilyana Indzheva Co-authored-by: Bilyana Indzheva <36890669+bili2002@users.noreply.github.com> Co-authored-by: Christian Bourjau --- docs/OperatorKernels.md | 1 + .../providers/cpu/cpu_execution_provider.cc | 6 + onnxruntime/core/providers/cpu/ml/ml_common.h | 58 +- .../core/providers/cpu/ml/tree_ensemble.cc | 59 ++ .../core/providers/cpu/ml/tree_ensemble.h | 25 + .../cpu/ml/tree_ensemble_aggregator.h | 40 +- .../cpu/ml/tree_ensemble_attribute.h | 321 ++++++++++ .../providers/cpu/ml/tree_ensemble_common.h | 547 +++++++++--------- .../providers/cpu/ml/tree_ensemble_helper.cc | 66 +-- .../providers/cpu/ml/tree_ensemble_helper.h | 2 + .../python/tools/transformers/float16.py | 1 + .../providers/cpu/ml/tree_ensembler_test.cc | 294 ++++++++++ .../providers/cpu/ml/treeregressor_test.cc | 84 +++ 13 files changed, 1155 insertions(+), 349 deletions(-) create mode 100644 onnxruntime/core/providers/cpu/ml/tree_ensemble.cc create mode 100644 onnxruntime/core/providers/cpu/ml/tree_ensemble.h create mode 100644 onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h create mode 100644 onnxruntime/test/providers/cpu/ml/tree_ensembler_test.cc diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index b376ae03e9cdd..eeb8ebb3ccefe 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -453,6 +453,7 @@ Do not modify directly.* |SVMClassifier|*in* X:**T1**
*out* Y:**T2**
*out* Z:**tensor(float)**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int64), tensor(string)| |SVMRegressor|*in* X:**T**
*out* Y:**tensor(float)**|1+|**T** = tensor(float)| |Scaler|*in* X:**T**
*out* Y:**tensor(float)**|1+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|TreeEnsemble|*in* X:**T**
*out* Y:**T**|5+|**T** = tensor(double), tensor(float)| |TreeEnsembleClassifier|*in* X:**T1**
*out* Y:**T2**
*out* Z:**tensor(float)**|3+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int64), tensor(string)| |||[1, 2]|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int64), tensor(string)| |TreeEnsembleRegressor|*in* X:**T**
*out* Y:**tensor(float)**|3+|**T** = tensor(double), tensor(float)| diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 65eeb4b84e193..0499a15e1df0a 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -2925,6 +2925,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, int32_t, TreeEnsembleClassifier); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, float, TreeEnsembleRegressor); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 3, double, TreeEnsembleRegressor); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 5, float, TreeEnsemble); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 5, double, TreeEnsemble); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 4, float_string, LabelEncoder); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 4, string_float, LabelEncoder); @@ -3043,6 +3045,10 @@ Status RegisterOnnxMLOperatorKernels(KernelRegistry& kernel_registry) { TreeEnsembleRegressor)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/ml/ml_common.h b/onnxruntime/core/providers/cpu/ml/ml_common.h index 2f4ebeabe043e..3359b2a69fe83 100644 --- a/onnxruntime/core/providers/cpu/ml/ml_common.h +++ b/onnxruntime/core/providers/cpu/ml/ml_common.h @@ -20,44 +20,48 @@ enum class OUTPUT_MODE { ALL_SCORES }; -enum NODE_MODE : uint8_t { - LEAF = 1, - BRANCH_LEQ = 2, - BRANCH_LT = 4, - BRANCH_GTE = 6, - BRANCH_GT = 8, - BRANCH_EQ = 10, - BRANCH_NEQ = 12 +enum NODE_MODE_ONNX : uint8_t { + BRANCH_LEQ = 0, + BRANCH_LT = 1, + BRANCH_GTE = 2, + BRANCH_GT = 3, + BRANCH_EQ = 4, + BRANCH_NEQ = 5, + BRANCH_MEMBER = 6, + LEAF = 7, }; -static inline NODE_MODE MakeTreeNodeMode(const std::string& input) { +static inline NODE_MODE_ONNX MakeTreeNodeMode(const std::string& input) { if (input == "BRANCH_LEQ") { - return NODE_MODE::BRANCH_LEQ; + return NODE_MODE_ONNX::BRANCH_LEQ; } if (input == "LEAF") { - return NODE_MODE::LEAF; + return NODE_MODE_ONNX::LEAF; } if (input == "BRANCH_LT") { - return NODE_MODE::BRANCH_LT; + return NODE_MODE_ONNX::BRANCH_LT; } if (input == "BRANCH_GTE") { - return NODE_MODE::BRANCH_GTE; + return NODE_MODE_ONNX::BRANCH_GTE; } if (input == "BRANCH_GT") { - return NODE_MODE::BRANCH_GT; + return NODE_MODE_ONNX::BRANCH_GT; } if (input == "BRANCH_EQ") { - return NODE_MODE::BRANCH_EQ; + return NODE_MODE_ONNX::BRANCH_EQ; } - return NODE_MODE::BRANCH_NEQ; + if (input == "BRANCH_MEMBER") { + return NODE_MODE_ONNX::BRANCH_MEMBER; + } + return NODE_MODE_ONNX::BRANCH_NEQ; } -enum class POST_EVAL_TRANSFORM { - NONE, - LOGISTIC, - SOFTMAX, - SOFTMAX_ZERO, - PROBIT +enum class POST_EVAL_TRANSFORM : int64_t { + NONE = 0, + LOGISTIC = 1, + SOFTMAX = 2, + SOFTMAX_ZERO = 3, + PROBIT = 4 }; static inline POST_EVAL_TRANSFORM MakeTransform(const std::string& input) { @@ -76,11 +80,11 @@ static inline POST_EVAL_TRANSFORM MakeTransform(const std::string& input) { return POST_EVAL_TRANSFORM::PROBIT; } -enum class AGGREGATE_FUNCTION { - AVERAGE, - SUM, - MIN, - MAX +enum class AGGREGATE_FUNCTION : int64_t { + AVERAGE = 0, + SUM = 1, + MIN = 2, + MAX = 3 }; static inline AGGREGATE_FUNCTION MakeAggregateFunction(const std::string& input) { diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble.cc b/onnxruntime/core/providers/cpu/ml/tree_ensemble.cc new file mode 100644 index 0000000000000..3ff501d96b72d --- /dev/null +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble.cc @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cpu/ml/tree_ensemble.h" +#include "core/providers/cpu/ml/tree_ensemble_helper.h" +#include "core/common/inlined_containers_fwd.h" + +namespace onnxruntime { +namespace ml { + +ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( + TreeEnsemble, + 5, + float, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()).MayInplace(0, 0), + TreeEnsemble); + +ONNX_CPU_OPERATOR_TYPED_ML_KERNEL( + TreeEnsemble, + 5, + double, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()).MayInplace(0, 0), + TreeEnsemble); + +template +TreeEnsemble::TreeEnsemble(const OpKernelInfo& info) : OpKernel(info) { + if constexpr (std::is_same::value) { + p_tree_ensemble_ = std::make_unique>(); + } else { + p_tree_ensemble_ = std::make_unique>(); + } + ORT_THROW_IF_ERROR(p_tree_ensemble_->Init(info)); +} + +template +Status TreeEnsemble::GetRemovableAttributes(InlinedVector& removable_attributes) const { + InlinedVector names{ + "leaf_targetids", "leaf_weights", "membership_values", "nodes_falseleafs", + "nodes_falsenodeids", "nodes_featureids", "nodes_hitrates", "nodes_missing_value_tracks_true", + "nodes_modes", "nodes_splits", "nodes_trueleafs", "nodes_truenodeids"}; + removable_attributes.swap(names); + return Status::OK(); +} + +template +common::Status TreeEnsemble::Compute(OpKernelContext* context) const { + const auto* X = context->Input(0); + if (X == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); + if (X->Shape().NumDimensions() == 0) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "Input shape needs to be at least a single dimension."); + } + int64_t N = X->Shape().NumDimensions() == 1 ? 1 : X->Shape()[0]; + Tensor* Y = context->Output(0, {N, p_tree_ensemble_->get_target_or_class_count()}); + return p_tree_ensemble_->compute(context, X, Y, NULL); +} + +} // namespace ml +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble.h new file mode 100644 index 0000000000000..697aae045a7e3 --- /dev/null +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "tree_ensemble_common.h" + +namespace onnxruntime { +namespace ml { +template +class TreeEnsemble final : public OpKernel { + typedef T InputType; // input type + typedef float OutputType; // output type + public: + explicit TreeEnsemble(const OpKernelInfo& info); + common::Status Compute(OpKernelContext* context) const override; + Status GetRemovableAttributes(InlinedVector& removable_attributes) const override; + + private: + // Pointer on one instance of + // detail::TreeEnsembleCommonV5 + // where ThresholdType is defined after accessing the attributes. + std::unique_ptr p_tree_ensemble_; +}; +} // namespace ml +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h index b031a6f0cefa3..bf3fd37d10f5c 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h @@ -78,6 +78,40 @@ union PtrOrWeight { } weight_data; }; +enum NODE_MODE_ORT : uint8_t { + LEAF = 1, + BRANCH_LEQ = 2, + BRANCH_LT = 4, + BRANCH_GTE = 6, + BRANCH_GT = 8, + BRANCH_EQ = 10, + BRANCH_NEQ = 12, + BRANCH_MEMBER = 14, +}; + +inline NODE_MODE_ORT Convert_NODE_MODE_ONNX_to_ORT(NODE_MODE_ONNX node_mode) { + switch (node_mode) { + case NODE_MODE_ONNX::LEAF: + return NODE_MODE_ORT::LEAF; + case NODE_MODE_ONNX::BRANCH_LEQ: + return NODE_MODE_ORT::BRANCH_LEQ; + case NODE_MODE_ONNX::BRANCH_LT: + return NODE_MODE_ORT::BRANCH_LT; + case NODE_MODE_ONNX::BRANCH_GTE: + return NODE_MODE_ORT::BRANCH_GTE; + case NODE_MODE_ONNX::BRANCH_GT: + return NODE_MODE_ORT::BRANCH_GT; + case NODE_MODE_ONNX::BRANCH_EQ: + return NODE_MODE_ORT::BRANCH_EQ; + case NODE_MODE_ONNX::BRANCH_NEQ: + return NODE_MODE_ORT::BRANCH_NEQ; + case NODE_MODE_ONNX::BRANCH_MEMBER: + return NODE_MODE_ORT::BRANCH_MEMBER; + default: + ORT_THROW("Unexpected value for node_mode"); + }; +} + template struct TreeNodeElement { int feature_id; @@ -98,10 +132,10 @@ struct TreeNodeElement { // weight in array `TreeEnsembleCommon::weights_`. If the number of targets or classes is one, the weight is also // stored in `value_or_unique_weight`. PtrOrWeight truenode_or_weight; - uint8_t flags; + NODE_MODE_ORT flags; - inline NODE_MODE mode() const { return NODE_MODE(flags & 0xF); } - inline bool is_not_leaf() const { return !(flags & NODE_MODE::LEAF); } + inline NODE_MODE_ORT mode() const { return NODE_MODE_ORT(flags & 0xF); } + inline bool is_not_leaf() const { return !(flags & NODE_MODE_ORT::LEAF); } inline bool is_missing_track_true() const { return flags & MissingTrack::kTrue; } }; diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h new file mode 100644 index 0000000000000..d2d1ba9863ac7 --- /dev/null +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_attribute.h @@ -0,0 +1,321 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/inlined_containers.h" +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "ml_common.h" +#include "tree_ensemble_helper.h" +#include + +namespace onnxruntime { +namespace ml { +namespace detail { + +inline bool _isnan_(float x) { return std::isnan(x); } +inline bool _isnan_(double x) { return std::isnan(x); } +inline bool _isnan_(int64_t) { return false; } +inline bool _isnan_(int32_t) { return false; } + +template +struct TreeEnsembleAttributesV3 { + TreeEnsembleAttributesV3() {} + TreeEnsembleAttributesV3(const OpKernelInfo& info, bool classifier) { +#if !defined(ORT_MINIMAL_BUILD) + ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "base_values_as_tensor", base_values_as_tensor)); + ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_hitrates_as_tensor", nodes_hitrates_as_tensor)); + ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_values_as_tensor", nodes_values_as_tensor)); + if (classifier) { + ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "class_weights_as_tensor", target_class_weights_as_tensor)); + } else { + ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "target_weights_as_tensor", target_class_weights_as_tensor)); + } +#endif + + aggregate_function = info.GetAttrOrDefault("aggregate_function", "SUM"); + base_values = info.GetAttrsOrDefault("base_values"); + nodes_falsenodeids = info.GetAttrsOrDefault("nodes_falsenodeids"); + nodes_featureids = info.GetAttrsOrDefault("nodes_featureids"); + nodes_missing_value_tracks_true = info.GetAttrsOrDefault("nodes_missing_value_tracks_true"); + + std::vector nodes_modes_string = info.GetAttrsOrDefault("nodes_modes"); + nodes_modes.reserve(nodes_modes_string.size()); + for (auto s : nodes_modes_string) { + nodes_modes.emplace_back(MakeTreeNodeMode(s)); + } + + nodes_nodeids = info.GetAttrsOrDefault("nodes_nodeids"); + nodes_treeids = info.GetAttrsOrDefault("nodes_treeids"); + nodes_truenodeids = info.GetAttrsOrDefault("nodes_truenodeids"); + nodes_values = info.GetAttrsOrDefault("nodes_values"); + post_transform = info.GetAttrOrDefault("post_transform", "NONE"); + + if (classifier) { + target_class_ids = info.GetAttrsOrDefault("class_ids"); + target_class_nodeids = info.GetAttrsOrDefault("class_nodeids"); + target_class_treeids = info.GetAttrsOrDefault("class_treeids"); + target_class_weights = info.GetAttrsOrDefault("class_weights"); + classlabels_strings = info.GetAttrsOrDefault("classlabels_strings"); + classlabels_int64s = info.GetAttrsOrDefault("classlabels_int64s"); + n_targets_or_classes = classlabels_strings.empty() ? classlabels_int64s.size() + : classlabels_strings.size(); + } else { + n_targets_or_classes = info.GetAttrOrDefault("n_targets", 0); + target_class_ids = info.GetAttrsOrDefault("target_ids"); + target_class_nodeids = info.GetAttrsOrDefault("target_nodeids"); + target_class_treeids = info.GetAttrsOrDefault("target_treeids"); + target_class_weights = info.GetAttrsOrDefault("target_weights"); + + ORT_ENFORCE(n_targets_or_classes > 0); + ORT_ENFORCE(nodes_falsenodeids.size() == nodes_featureids.size()); + ORT_ENFORCE(nodes_falsenodeids.size() == nodes_modes_string.size()); + ORT_ENFORCE(nodes_falsenodeids.size() == nodes_nodeids.size()); + ORT_ENFORCE(nodes_falsenodeids.size() == nodes_treeids.size()); + ORT_ENFORCE(nodes_falsenodeids.size() == nodes_truenodeids.size()); + ORT_ENFORCE(nodes_falsenodeids.size() == nodes_values.size() || + nodes_falsenodeids.size() == nodes_values_as_tensor.size()); + ORT_ENFORCE(target_class_ids.size() == target_class_nodeids.size()); + ORT_ENFORCE(target_class_ids.size() == target_class_treeids.size()); + ORT_ENFORCE(target_class_weights.empty() || target_class_ids.size() == target_class_weights.size()); + ORT_ENFORCE(base_values.empty() || base_values_as_tensor.empty()); + ORT_ENFORCE(nodes_hitrates.empty() || nodes_hitrates_as_tensor.empty()); + ORT_ENFORCE(nodes_values.empty() || nodes_values_as_tensor.empty()); + ORT_ENFORCE(target_class_weights.empty() || target_class_weights_as_tensor.empty()); + ORT_ENFORCE(nodes_modes_string.size() < std::numeric_limits::max()); + } + } + + std::string aggregate_function; + std::vector base_values; + std::vector base_values_as_tensor; + int64_t n_targets_or_classes; + std::vector nodes_falsenodeids; + std::vector nodes_featureids; + std::vector nodes_hitrates; + std::vector nodes_hitrates_as_tensor; + std::vector nodes_missing_value_tracks_true; + std::vector nodes_modes; + std::vector nodes_nodeids; + std::vector nodes_treeids; + std::vector nodes_truenodeids; + std::vector nodes_values; + std::vector nodes_values_as_tensor; + std::string post_transform; + std::vector target_class_ids; + std::vector target_class_nodeids; + std::vector target_class_treeids; + std::vector target_class_weights; + std::vector target_class_weights_as_tensor; + std::vector classlabels_strings; + std::vector classlabels_int64s; + std::vector class_labels; +}; + +template +struct TreeEnsembleAttributesV5 { + TreeEnsembleAttributesV5() {} + TreeEnsembleAttributesV5(const OpKernelInfo& info) { +#if !defined(ORT_MINIMAL_BUILD) + std::vector nodes_modes_i; + ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "leaf_weights", leaf_weights)); + ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "membership_values", membership_values)); + ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_hitrates", nodes_hitrates)); + ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_modes", nodes_modes_i)); + ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_splits", nodes_splits)); + nodes_modes.reserve(nodes_modes.size()); + for (auto i : nodes_modes_i) { + nodes_modes.push_back(static_cast(i)); + } +#else + // GetVectorAttrsOrDefault is not part of the minimal build. + // As a result, TreeEnsemble v5 cannot be available in this build. + ORT_THROW("TreeEnsemble(ai.onnx.ml==5) is not supported with the minimal build."); +#endif + + aggregate_function = info.GetAttrOrDefault("aggregate_function", 1); + leaf_targetids = info.GetAttrsOrDefault("leaf_targetids"); + n_targets = info.GetAttrOrDefault("n_targets", 0); + nodes_falseleafs = info.GetAttrsOrDefault("nodes_falseleafs"); + nodes_falsenodeids = info.GetAttrsOrDefault("nodes_falsenodeids"); + nodes_featureids = info.GetAttrsOrDefault("nodes_featureids"); + nodes_missing_value_tracks_true = info.GetAttrsOrDefault("nodes_missing_value_tracks_true"); + nodes_trueleafs = info.GetAttrsOrDefault("nodes_trueleafs"); + nodes_truenodeids = info.GetAttrsOrDefault("nodes_truenodeids"); + post_transform = info.GetAttrOrDefault("post_transform", 0); + tree_roots = info.GetAttrsOrDefault("tree_roots"); + } + + void convert_to_v3(TreeEnsembleAttributesV3& output) const { + // Doing all transformations to get the old format. + output.n_targets_or_classes = n_targets; + output.aggregate_function = aggregateFunctionToString(); + output.post_transform = postTransformToString(); + std::vector> membership_values_by_id; + getMembershipValuesById(membership_values_by_id); + transformInputAllTrees(output, membership_values_by_id); + } + + int64_t aggregate_function; + std::vector leaf_targetids; + std::vector leaf_weights; + std::vector membership_values; + int64_t n_targets; + std::vector nodes_falseleafs; + std::vector nodes_falsenodeids; + std::vector nodes_featureids; + std::vector nodes_hitrates; + std::vector nodes_missing_value_tracks_true; + std::vector nodes_modes; + std::vector nodes_splits; + std::vector nodes_trueleafs; + std::vector nodes_truenodeids; + int64_t post_transform; + std::vector tree_roots; + + private: + // `membership_values` are seperated by NAN for different nodes + // It is more convenient to preserve the values for each node in a vector + // The vector would be empty for nodes that are not `BRANCH_MEMBER` + void getMembershipValuesById(std::vector>& membership_values_by_id) const { + membership_values_by_id.clear(); + membership_values_by_id.reserve(nodes_modes.size()); + + size_t curr_id = 0; + for (const auto node_mode : nodes_modes) { + membership_values_by_id.emplace_back(); + if (node_mode != NODE_MODE_ONNX::BRANCH_MEMBER) { + continue; + } + + while (curr_id < membership_values.size() && !_isnan_(membership_values[curr_id])) { + membership_values_by_id.back().push_back(membership_values[curr_id++]); + } + curr_id++; + } + } + + std::string aggregateFunctionToString() const { + switch (aggregate_function) { + case static_cast(AGGREGATE_FUNCTION::AVERAGE): + return "AVERAGE"; + case static_cast(AGGREGATE_FUNCTION::SUM): + return "SUM"; + case static_cast(AGGREGATE_FUNCTION::MIN): + return "MIN"; + case static_cast(AGGREGATE_FUNCTION::MAX): + return "MAX"; + default: + ORT_THROW("Unknown value for aggregate_function."); + } + } + + std::string postTransformToString() const { + switch (post_transform) { + case static_cast(POST_EVAL_TRANSFORM::NONE): + return "NONE"; + case static_cast(POST_EVAL_TRANSFORM::SOFTMAX): + return "SOFTMAX"; + case static_cast(POST_EVAL_TRANSFORM::LOGISTIC): + return "LOGISTIC"; + case static_cast(POST_EVAL_TRANSFORM::SOFTMAX_ZERO): + return "SOFTMAX_ZERO"; + case static_cast(POST_EVAL_TRANSFORM::PROBIT): + return "PROBIT"; + default: + ORT_THROW("Unknown value for post_transform."); + } + } + + int64_t transformInputOneTree( + const size_t curr_id, const int64_t curr_treeid, const int64_t curr_nodeid, const size_t curr_membership_value_id, + const bool is_leaf, std::vector>& membership_values_by_id, + TreeEnsembleAttributesV3& output) const { + output.nodes_nodeids.push_back(curr_nodeid); + output.nodes_treeids.push_back(curr_treeid); + + if (is_leaf) { + output.nodes_modes.push_back(NODE_MODE_ONNX::LEAF); + output.target_class_ids.push_back(leaf_targetids[curr_id]); + output.target_class_nodeids.push_back(curr_nodeid); + output.target_class_treeids.push_back(curr_treeid); + output.target_class_weights_as_tensor.push_back(leaf_weights[curr_id]); + + // the below are irrelevant for a `LEAF` + output.nodes_featureids.push_back(0); + output.nodes_truenodeids.push_back(0); + output.nodes_falsenodeids.push_back(0); + output.nodes_values_as_tensor.push_back(0); + if (!nodes_hitrates.empty()) { + output.nodes_hitrates.push_back(0); + } + if (!nodes_missing_value_tracks_true.empty()) { + output.nodes_missing_value_tracks_true.push_back(0); + } + + return curr_nodeid; + } + + output.nodes_featureids.push_back(nodes_featureids[curr_id]); + if (!nodes_hitrates.empty()) { + output.nodes_hitrates_as_tensor.push_back(nodes_hitrates[curr_id]); + } + if (!nodes_missing_value_tracks_true.empty()) { + output.nodes_missing_value_tracks_true.push_back(nodes_missing_value_tracks_true[curr_id]); + } + + // unroll `BRANCH_MEMBER` to a chain of `BRANCH_EQ` + if (nodes_modes[curr_id] == NODE_MODE_ONNX::BRANCH_MEMBER) { + output.nodes_modes.push_back(NODE_MODE_ONNX::BRANCH_EQ); + output.nodes_values_as_tensor.push_back(membership_values_by_id[curr_id][curr_membership_value_id]); + } else { + output.nodes_modes.push_back(nodes_modes[curr_id]); + output.nodes_values_as_tensor.push_back(nodes_splits[curr_id]); + } + + size_t falsenodeid_id = output.nodes_falsenodeids.size(); + output.nodes_falsenodeids.push_back(0); // change after pushing truenode subtree + + int64_t true_nodeid = curr_nodeid + 1; + output.nodes_truenodeids.push_back(true_nodeid); + true_nodeid = transformInputOneTree(onnxruntime::narrow(nodes_truenodeids[curr_id]), + curr_treeid, true_nodeid, 0U, nodes_trueleafs[curr_id] != 0, + membership_values_by_id, output); + + int64_t false_nodeid = true_nodeid + 1; + output.nodes_falsenodeids[falsenodeid_id] = false_nodeid; + + // if node is `BRANCH_MEMBER` we are unrolling the `membership_values` for that node + // therefore if the value is not the last, the `falsenode_id` must be pointing to the "same" node with a different membership value + // so in that case we are only moving the pointer for `membership_values` + // + // otherwise, the `falsenode_id` is pointing to the real falsenode subtree + if (nodes_modes[curr_id] == NODE_MODE_ONNX::BRANCH_MEMBER && + curr_membership_value_id + 1 < membership_values_by_id[curr_id].size()) { + false_nodeid = transformInputOneTree(curr_id, curr_treeid, false_nodeid, curr_membership_value_id + 1, false, + membership_values_by_id, output); + } else { + false_nodeid = transformInputOneTree(onnxruntime::narrow(nodes_falsenodeids[curr_id]), + curr_treeid, false_nodeid, 0U, nodes_falseleafs[curr_id] != 0, + membership_values_by_id, output); + } + return false_nodeid; + } + + void transformInputAllTrees(TreeEnsembleAttributesV3& output, + std::vector>& membership_values_by_id) const { + int64_t curr_treeid = 0; + for (const int64_t& tree_root : tree_roots) { + size_t tree_root_size_t = onnxruntime::narrow(tree_root); + transformInputOneTree(tree_root_size_t, curr_treeid, 0, 0U, + nodes_falsenodeids[tree_root_size_t] == nodes_truenodeids[tree_root_size_t], + membership_values_by_id, output); + curr_treeid++; + } + } +}; + +} // namespace detail +} // namespace ml +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h index 94f79518ae8da..10d4db0e0e3b0 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h @@ -3,15 +3,21 @@ #pragma once -#include "tree_ensemble_aggregator.h" #include #include "core/platform/threadpool.h" #include "tree_ensemble_helper.h" +#include "tree_ensemble_attribute.h" +#include "tree_ensemble_aggregator.h" namespace onnxruntime { namespace ml { namespace detail { +/** + * These attributes are the kernel attributes. They are different from the onnx operator attributes + * to improve the computation efficiency. The initialization consists in moving the onnx attributes + * into the kernel attributes. + */ class TreeEnsembleCommonAttributes { public: int64_t get_target_or_class_count() const { return this->n_targets_or_classes_; } @@ -57,27 +63,7 @@ class TreeEnsembleCommon : public TreeEnsembleCommonAttributes { Status Init(int parallel_tree, int parallel_tree_N, int parallel_N, - const std::string& aggregate_function, - const std::vector& base_values, - const std::vector& base_values_as_tensor, - int64_t n_targets_or_classes, - const std::vector& nodes_falsenodeids, - const std::vector& nodes_featureids, - const std::vector& nodes_hitrates, - const std::vector& nodes_hitrates_as_tensor, - const std::vector& nodes_missing_value_tracks_true, - const std::vector& nodes_modes, - const std::vector& nodes_nodeids, - const std::vector& nodes_treeids, - const std::vector& nodes_truenodeids, - const std::vector& nodes_values, - const std::vector& nodes_values_as_tensor, - const std::string& post_transform, - const std::vector& target_class_ids, - const std::vector& target_class_nodeids, - const std::vector& target_class_treeids, - const std::vector& target_class_weights, - const std::vector& target_class_weights_as_tensor); + const TreeEnsembleAttributesV3& attributes); protected: TreeNodeElement* ProcessTreeNodeLeave(TreeNodeElement* root, @@ -87,49 +73,52 @@ class TreeEnsembleCommon : public TreeEnsembleCommonAttributes { void ComputeAgg(concurrency::ThreadPool* ttp, const Tensor* X, Tensor* Y, Tensor* label, const AGG& agg) const; private: - size_t AddNodes(const size_t i, const InlinedVector& cmodes, const InlinedVector& truenode_ids, - const InlinedVector& falsenode_ids, const std::vector& nodes_featureids, - const std::vector& nodes_values_as_tensor, const std::vector& node_values, - const std::vector& nodes_missing_value_tracks_true, std::vector& updated_mapping, - int64_t tree_id, const InlinedVector& node_tree_ids); + bool CheckIfSubtreesAreEqual(const size_t left_id, const size_t right_id, const int64_t tree_id, const InlinedVector& cmodes, + const InlinedVector& truenode_ids, const InlinedVector& falsenode_ids, gsl::span nodes_featureids, + gsl::span nodes_values_as_tensor, gsl::span node_values, + gsl::span target_class_weights, gsl::span target_class_weights_as_tensor, + const InlinedVector& node_tree_ids, InlinedVector> indices); + size_t AddNodes(const size_t i, const InlinedVector& cmodes, const InlinedVector& truenode_ids, + const InlinedVector& falsenode_ids, gsl::span nodes_featureids, + gsl::span nodes_values_as_tensor, gsl::span node_values, + gsl::span nodes_missing_value_tracks_true, std::vector& updated_mapping, + int64_t tree_id, const InlinedVector& node_tree_ids, gsl::span target_class_weights, + gsl::span target_class_weights_as_tensor, InlinedVector>& indices); }; +// Below is simple implementation of `bit_cast` as it is supported from c++20 and the current supported version is c++17 +// Remove it when that is not the case +template +std::enable_if_t< + sizeof(To) == sizeof(From) && + std::is_trivially_copyable_v && + std::is_trivially_copyable_v, + To> + // constexpr support needs compiler magic + static bit_cast(const From& src) noexcept { + static_assert(std::is_trivially_constructible_v, + "This implementation additionally requires " + "destination type to be trivially constructible"); + + To dst; + std::memcpy(&dst, &src, sizeof(To)); + return dst; +} + +template +std::conditional_t bit_cast_int(T val) { + if constexpr (sizeof(T) == sizeof(uint32_t)) { + return bit_cast(val); + } else if constexpr (sizeof(T) == sizeof(uint64_t)) { + return bit_cast(val); + } + static_assert(sizeof(T) == sizeof(uint32_t) || sizeof(T) == sizeof(uint64_t)); +} + template Status TreeEnsembleCommon::Init(const OpKernelInfo& info) { - std::vector base_values_as_tensor, nodes_hitrates_as_tensor, - nodes_values_as_tensor, target_weights_as_tensor; -#if !defined(ORT_MINIMAL_BUILD) - ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "base_values_as_tensor", base_values_as_tensor)); - ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_hitrates_as_tensor", nodes_hitrates_as_tensor)); - ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_values_as_tensor", nodes_values_as_tensor)); - ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "target_weights_as_tensor", target_weights_as_tensor)); -#endif - - return Init( - 80, - 128, - 50, - info.GetAttrOrDefault("aggregate_function", "SUM"), - info.GetAttrsOrDefault("base_values"), - base_values_as_tensor, - info.GetAttrOrDefault("n_targets", 0), - info.GetAttrsOrDefault("nodes_falsenodeids"), - info.GetAttrsOrDefault("nodes_featureids"), - info.GetAttrsOrDefault("nodes_hitrates"), - nodes_hitrates_as_tensor, - info.GetAttrsOrDefault("nodes_missing_value_tracks_true"), - info.GetAttrsOrDefault("nodes_modes"), - info.GetAttrsOrDefault("nodes_nodeids"), - info.GetAttrsOrDefault("nodes_treeids"), - info.GetAttrsOrDefault("nodes_truenodeids"), - info.GetAttrsOrDefault("nodes_values"), - nodes_values_as_tensor, - info.GetAttrOrDefault("post_transform", "NONE"), - info.GetAttrsOrDefault("target_ids"), - info.GetAttrsOrDefault("target_nodeids"), - info.GetAttrsOrDefault("target_treeids"), - info.GetAttrsOrDefault("target_weights"), - target_weights_as_tensor); + TreeEnsembleAttributesV3 attributes(info, false); + return Init(80, 128, 50, attributes); } template @@ -137,72 +126,35 @@ Status TreeEnsembleCommon::Init( int parallel_tree, int parallel_tree_N, int parallel_N, - const std::string& aggregate_function, - const std::vector& base_values, - const std::vector& base_values_as_tensor, - int64_t n_targets_or_classes, - const std::vector& nodes_falsenodeids, - const std::vector& nodes_featureids, - const std::vector& nodes_hitrates, - const std::vector& nodes_hitrates_as_tensor, - const std::vector& nodes_missing_value_tracks_true, - const std::vector& nodes_modes, - const std::vector& nodes_nodeids, - const std::vector& nodes_treeids, - const std::vector& nodes_truenodeids, - const std::vector& nodes_values, - const std::vector& nodes_values_as_tensor, - const std::string& post_transform, - const std::vector& target_class_ids, - const std::vector& target_class_nodeids, - const std::vector& target_class_treeids, - const std::vector& target_class_weights, - const std::vector& target_class_weights_as_tensor) { + const TreeEnsembleAttributesV3& attributes) { parallel_tree_ = parallel_tree; parallel_tree_N_ = parallel_tree_N; parallel_N_ = parallel_N; - ORT_ENFORCE(n_targets_or_classes > 0); - ORT_ENFORCE(nodes_falsenodeids.size() == nodes_featureids.size()); - ORT_ENFORCE(nodes_falsenodeids.size() == nodes_modes.size()); - ORT_ENFORCE(nodes_falsenodeids.size() == nodes_nodeids.size()); - ORT_ENFORCE(nodes_falsenodeids.size() == nodes_treeids.size()); - ORT_ENFORCE(nodes_falsenodeids.size() == nodes_truenodeids.size()); - ORT_ENFORCE(nodes_falsenodeids.size() == nodes_values.size() || - nodes_falsenodeids.size() == nodes_values_as_tensor.size()); - ORT_ENFORCE(target_class_ids.size() == target_class_nodeids.size()); - ORT_ENFORCE(target_class_ids.size() == target_class_treeids.size()); - ORT_ENFORCE(target_class_weights.empty() || target_class_ids.size() == target_class_weights.size()); - ORT_ENFORCE(base_values.empty() || base_values_as_tensor.empty()); - ORT_ENFORCE(nodes_hitrates.empty() || nodes_hitrates_as_tensor.empty()); - ORT_ENFORCE(nodes_values.empty() || nodes_values_as_tensor.empty()); - ORT_ENFORCE(target_class_weights.empty() || target_class_weights_as_tensor.empty()); - - aggregate_function_ = MakeAggregateFunction(aggregate_function); - post_transform_ = MakeTransform(post_transform); - if (!base_values_as_tensor.empty()) { - ORT_ENFORCE(base_values.empty()); - base_values_ = base_values_as_tensor; + aggregate_function_ = MakeAggregateFunction(attributes.aggregate_function); + post_transform_ = MakeTransform(attributes.post_transform); + if (!attributes.base_values_as_tensor.empty()) { + ORT_ENFORCE(attributes.base_values.empty()); + base_values_ = attributes.base_values_as_tensor; } else { - base_values_.reserve(base_values.size()); - for (size_t i = 0, limit = base_values.size(); i < limit; ++i) { - base_values_.push_back(static_cast(base_values[i])); + base_values_.reserve(attributes.base_values.size()); + for (size_t i = 0, limit = attributes.base_values.size(); i < limit; ++i) { + base_values_.push_back(static_cast(attributes.base_values[i])); } } - n_targets_or_classes_ = n_targets_or_classes; + n_targets_or_classes_ = attributes.n_targets_or_classes; max_tree_depth_ = 1000; - ORT_ENFORCE(nodes_modes.size() < std::numeric_limits::max()); // Additional members size_t limit; uint32_t i; - InlinedVector cmodes; - cmodes.reserve(nodes_modes.size()); + InlinedVector cmodes; + cmodes.reserve(attributes.nodes_modes.size()); same_mode_ = true; int fpos = -1; - for (i = 0, limit = nodes_modes.size(); i < limit; ++i) { - cmodes.push_back(MakeTreeNodeMode(nodes_modes[i])); - if (cmodes[i] == NODE_MODE::LEAF) continue; + for (i = 0, limit = attributes.nodes_modes.size(); i < limit; ++i) { + cmodes.push_back(attributes.nodes_modes[i]); + if (cmodes[i] == NODE_MODE_ONNX::LEAF) continue; if (fpos == -1) { fpos = static_cast(i); continue; @@ -210,7 +162,7 @@ Status TreeEnsembleCommon::Init( if (cmodes[i] != cmodes[fpos]) same_mode_ = false; } - n_nodes_ = nodes_treeids.size(); + n_nodes_ = attributes.nodes_treeids.size(); limit = static_cast(n_nodes_); InlinedVector node_tree_ids; node_tree_ids.reserve(limit); @@ -227,7 +179,7 @@ Status TreeEnsembleCommon::Init( // Build node_tree_ids and node_tree_ids_map and truenode_ids and falsenode_ids for (i = 0; i < limit; ++i) { - TreeNodeElementId node_tree_id{static_cast(nodes_treeids[i]), static_cast(nodes_nodeids[i])}; + TreeNodeElementId node_tree_id{static_cast(attributes.nodes_treeids[i]), static_cast(attributes.nodes_nodeids[i])}; auto p = node_tree_ids_map.insert(std::pair(node_tree_id, i)); if (!p.second) { ORT_THROW("Node ", node_tree_id.node_id, " in tree ", node_tree_id.tree_id, " is already there."); @@ -237,13 +189,13 @@ Status TreeEnsembleCommon::Init( TreeNodeElementId coor; for (i = 0; i < limit; ++i) { - if (cmodes[i] == NODE_MODE::LEAF) { + if (cmodes[i] == NODE_MODE_ONNX::LEAF) { truenode_ids.push_back(0); falsenode_ids.push_back(0); } else { TreeNodeElementId& node_tree_id = node_tree_ids[i]; coor.tree_id = node_tree_id.tree_id; - coor.node_id = static_cast(nodes_truenodeids[i]); + coor.node_id = static_cast(attributes.nodes_truenodeids[i]); ORT_ENFORCE((coor.node_id >= 0 && coor.node_id < n_nodes_)); auto found = node_tree_ids_map.find(coor); @@ -255,7 +207,7 @@ Status TreeEnsembleCommon::Init( } truenode_ids.emplace_back(found->second); - coor.node_id = static_cast(nodes_falsenodeids[i]); + coor.node_id = static_cast(attributes.nodes_falsenodeids[i]); ORT_ENFORCE((coor.node_id >= 0 && coor.node_id < n_nodes_)); found = node_tree_ids_map.find(coor); if (found == node_tree_ids_map.end()) { @@ -270,41 +222,38 @@ Status TreeEnsembleCommon::Init( } } + // Sort targets + InlinedVector> indices; + indices.reserve(attributes.target_class_nodeids.size()); + for (i = 0, limit = attributes.target_class_nodeids.size(); i < limit; i++) { + indices.emplace_back( + TreeNodeElementId{attributes.target_class_treeids[i], attributes.target_class_nodeids[i]}, i); + } + + std::sort(indices.begin(), indices.end()); + // Let's construct nodes_ such that the false branch is always the next element in nodes_. // updated_mapping will translates the old position of each node to the new node position in nodes_. - std::vector updated_mapping(nodes_treeids.size(), 0); + std::vector updated_mapping(attributes.nodes_treeids.size(), 0); int64_t previous_tree_id = -1; for (i = 0; i < n_nodes_; ++i) { if (previous_tree_id == -1 || (previous_tree_id != node_tree_ids[i].tree_id)) { // New tree. int64_t tree_id = node_tree_ids[i].tree_id; size_t root_position = - AddNodes(i, cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, nodes_values, - nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids); + AddNodes(i, cmodes, truenode_ids, falsenode_ids, attributes.nodes_featureids, attributes.nodes_values_as_tensor, attributes.nodes_values, + attributes.nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids, + attributes.target_class_weights, attributes.target_class_weights_as_tensor, indices); roots_.push_back(&nodes_[root_position]); previous_tree_id = tree_id; } } - n_trees_ = roots_.size(); - if (((int64_t)nodes_.size()) != n_nodes_) { - ORT_THROW("Number of nodes in nodes_ (", nodes_.size(), ") is different from n_nodes (", n_nodes_, ")."); - } - - // Sort targets - InlinedVector> indices; - indices.reserve(target_class_nodeids.size()); - for (i = 0, limit = target_class_nodeids.size(); i < limit; i++) { - indices.emplace_back( - std::pair(TreeNodeElementId{target_class_treeids[i], target_class_nodeids[i]}, i)); - } - - std::sort(indices.begin(), indices.end()); TreeNodeElementId ind; SparseValue w; size_t indi; - for (indi = 0, limit = target_class_nodeids.size(); indi < limit; ++indi) { + for (indi = 0, limit = attributes.target_class_nodeids.size(); indi < limit; ++indi) { ind = indices[indi].first; i = indices[indi].second; auto found = node_tree_ids_map.find(ind); @@ -319,9 +268,10 @@ Status TreeEnsembleCommon::Init( // ORT_THROW("Node ", ind.tree_id, "-", ind.node_id, " is not a leaf."); continue; } - w.i = target_class_ids[i]; - w.value = target_class_weights_as_tensor.empty() ? static_cast(target_class_weights[i]) - : target_class_weights_as_tensor[i]; + w.i = attributes.target_class_ids[i]; + w.value = attributes.target_class_weights_as_tensor.empty() + ? static_cast(attributes.target_class_weights[i]) + : attributes.target_class_weights_as_tensor[i]; if (leaf.truenode_or_weight.weight_data.n_weights == 0) { leaf.truenode_or_weight.weight_data.weight = static_cast(weights_.size()); leaf.value_or_unique_weight = w.value; @@ -331,7 +281,7 @@ Status TreeEnsembleCommon::Init( } has_missing_tracks_ = false; - for (auto itm = nodes_missing_value_tracks_true.begin(); itm != nodes_missing_value_tracks_true.end(); ++itm) { + for (auto itm = attributes.nodes_missing_value_tracks_true.begin(); itm != attributes.nodes_missing_value_tracks_true.end(); ++itm) { if (*itm) { has_missing_tracks_ = true; break; @@ -341,13 +291,58 @@ Status TreeEnsembleCommon::Init( return Status::OK(); } +template +bool TreeEnsembleCommon::CheckIfSubtreesAreEqual( + const size_t left_id, const size_t right_id, const int64_t tree_id, const InlinedVector& cmodes, + const InlinedVector& truenode_ids, const InlinedVector& falsenode_ids, gsl::span nodes_featureids, + gsl::span nodes_values_as_tensor, gsl::span node_values, + gsl::span target_class_weights, gsl::span target_class_weights_as_tensor, + const InlinedVector& node_tree_ids, InlinedVector> indices) { + // Leaves have values set at 0 + if (cmodes[left_id] != cmodes[right_id] || nodes_featureids[left_id] != nodes_featureids[right_id] || + (!nodes_values_as_tensor.empty() && nodes_values_as_tensor[left_id] != nodes_values_as_tensor[right_id]) || + (nodes_values_as_tensor.empty() && node_values[left_id] != node_values[right_id])) { + return false; + } + + if (cmodes[left_id] == NODE_MODE_ONNX::LEAF) { + const auto left_target_node = std::lower_bound(indices.begin(), indices.end(), std::make_pair(node_tree_ids[left_id], uint32_t(0)))->second; + const auto right_target_node = std::lower_bound(indices.begin(), indices.end(), std::make_pair(node_tree_ids[right_id], uint32_t(0)))->second; + + if (target_class_weights_as_tensor.empty()) { + return target_class_weights[left_target_node] == target_class_weights[right_target_node]; + } else { + return target_class_weights_as_tensor[left_target_node] == target_class_weights_as_tensor[right_target_node]; + } + } + + return CheckIfSubtreesAreEqual(falsenode_ids[left_id], falsenode_ids[right_id], tree_id, cmodes, truenode_ids, falsenode_ids, nodes_featureids, + nodes_values_as_tensor, node_values, target_class_weights, target_class_weights_as_tensor, node_tree_ids, indices) && + CheckIfSubtreesAreEqual(truenode_ids[left_id], truenode_ids[right_id], tree_id, cmodes, truenode_ids, falsenode_ids, nodes_featureids, + nodes_values_as_tensor, node_values, target_class_weights, target_class_weights_as_tensor, node_tree_ids, indices); +} + +inline void UpdateThreshold(double val, double& mask) { + uint64_t new_mask = bit_cast(mask) | (1ll << (static_cast(val) - 1)); + mask = bit_cast(new_mask); +} + +inline void UpdateThreshold(float val, float& mask) { + uint32_t new_mask = bit_cast(mask) | (1 << (static_cast(val) - 1)); + mask = bit_cast(new_mask); +} + +#define BITCOUNT(T) int64_t(sizeof(T) * 8) +#define CANMASK(v, T) (v >= 1 && v <= BITCOUNT(T)) && v == std::floor(v) + template size_t TreeEnsembleCommon::AddNodes( - const size_t i, const InlinedVector& cmodes, const InlinedVector& truenode_ids, - const InlinedVector& falsenode_ids, const std::vector& nodes_featureids, - const std::vector& nodes_values_as_tensor, const std::vector& node_values, - const std::vector& nodes_missing_value_tracks_true, std::vector& updated_mapping, int64_t tree_id, - const InlinedVector& node_tree_ids) { + const size_t i, const InlinedVector& cmodes, const InlinedVector& truenode_ids, + const InlinedVector& falsenode_ids, gsl::span nodes_featureids, + gsl::span nodes_values_as_tensor, gsl::span node_values, + gsl::span nodes_missing_value_tracks_true, std::vector& updated_mapping, int64_t tree_id, + const InlinedVector& node_tree_ids, gsl::span target_class_weights, + gsl::span target_class_weights_as_tensor, InlinedVector>& indices) { // Validate this index maps to the same tree_id as the one we should be building. if (node_tree_ids[i].tree_id != tree_id) { ORT_THROW("Tree id mismatch. Expected ", tree_id, " but got ", node_tree_ids[i].tree_id, " at position ", i); @@ -364,28 +359,59 @@ size_t TreeEnsembleCommon::AddNodes( updated_mapping[i] = node_pos; TreeNodeElement node; - node.flags = static_cast(cmodes[i]); + node.flags = Convert_NODE_MODE_ONNX_to_ORT(cmodes[i]); node.feature_id = static_cast(nodes_featureids[i]); if (node.feature_id > max_feature_id_) { max_feature_id_ = node.feature_id; } - node.value_or_unique_weight = - nodes_values_as_tensor.empty() ? static_cast(node_values[i]) : nodes_values_as_tensor[i]; + + node.value_or_unique_weight = 0; + const ThresholdType node_threshold = nodes_values_as_tensor.empty() ? static_cast(node_values[i]) : nodes_values_as_tensor[i]; + if (node.flags == NODE_MODE_ORT::BRANCH_EQ && CANMASK(node_threshold, ThresholdType)) { + UpdateThreshold(node_threshold, node.value_or_unique_weight); + node.flags = NODE_MODE_ORT::BRANCH_MEMBER; + } else { + node.value_or_unique_weight = node_threshold; + } + if (i < static_cast(nodes_missing_value_tracks_true.size()) && nodes_missing_value_tracks_true[i] == 1) { - node.flags |= static_cast(MissingTrack::kTrue); + node.flags = static_cast(static_cast(node.flags) | static_cast(MissingTrack::kTrue)); } nodes_.push_back(std::move(node)); if (nodes_[node_pos].is_not_leaf()) { + size_t falsenode_id = falsenode_ids[i]; + + // Categoricals are represented as a chain of `EQ` nodes where the subtree for the true child is identical for all nodes in the chain + // Below we are folding together these nodes into one of mode `BRANCH_MEMBER` + // The threshold of this node should be interpreted as a bitmask showing which categoricals values were found in the chain + // Afterwards, when looking whether a feature is included we can do an `and` with the mask of the node + // and the one of the feature (the mask has only one bit set on the place for its value) + // Beware that if a category is bigger than the threshold type, the node stays as `EQ` and no combination is done + if (nodes_[node_pos].flags == NODE_MODE_ORT::BRANCH_MEMBER) { + ThresholdType falsenode_threshold = nodes_values_as_tensor.empty() ? static_cast(node_values[falsenode_id]) : nodes_values_as_tensor[falsenode_id]; + + while (cmodes[falsenode_id] == NODE_MODE_ONNX::BRANCH_EQ && nodes_[node_pos].feature_id == nodes_featureids[falsenode_id] && + CANMASK(falsenode_threshold, ThresholdType) && + CheckIfSubtreesAreEqual(truenode_ids[i], truenode_ids[falsenode_id], tree_id, cmodes, truenode_ids, falsenode_ids, + nodes_featureids, nodes_values_as_tensor, node_values, target_class_weights, target_class_weights_as_tensor, node_tree_ids, indices)) { + UpdateThreshold(falsenode_threshold, nodes_[node_pos].value_or_unique_weight); + falsenode_id = falsenode_ids[falsenode_id]; + falsenode_threshold = nodes_values_as_tensor.empty() ? static_cast(node_values[falsenode_id]) : nodes_values_as_tensor[falsenode_id]; + } + } + size_t false_branch = - AddNodes(falsenode_ids[i], cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, - node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids); + AddNodes(falsenode_id, cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, + node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids, + target_class_weights, target_class_weights_as_tensor, indices); if (false_branch != node_pos + 1) { ORT_THROW("False node must always be the next node, but it isn't at index ", node_pos, " with flags ", static_cast(nodes_[node_pos].flags)); } size_t true_branch = AddNodes(truenode_ids[i], cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, - node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids); + node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids, + target_class_weights, target_class_weights_as_tensor, indices); // We don't need to store the false branch pointer since we know it is always in the immediate next entry in nodes_. // nodes_[node_pos].falsenode_inc_or_n_weights.ptr = &nodes_[false_branch]; nodes_[node_pos].truenode_or_weight.ptr = &nodes_[true_branch]; @@ -684,10 +710,12 @@ void TreeEnsembleCommon::ComputeAgg(concur } \ } -inline bool _isnan_(float x) { return std::isnan(x); } -inline bool _isnan_(double x) { return std::isnan(x); } -inline bool _isnan_(int64_t) { return false; } -inline bool _isnan_(int32_t) { return false; } +// Check whether the feature value is set true in the mask +template +inline bool SetMembershipCheck(T1 val, T2 mask) { + const int64_t val_as_int = static_cast(val); + return CANMASK(val, T2) && (((1ll << (val_as_int - 1)) & bit_cast_int(mask)) != 0); +} template TreeNodeElement* @@ -696,7 +724,7 @@ TreeEnsembleCommon::ProcessTreeNodeLeave( InputType val; if (same_mode_) { switch (root->mode()) { - case NODE_MODE::BRANCH_LEQ: + case NODE_MODE_ORT::BRANCH_LEQ: if (has_missing_tracks_) { while (root->is_not_leaf()) { val = x_data[root->feature_id]; @@ -711,22 +739,36 @@ TreeEnsembleCommon::ProcessTreeNodeLeave( } } break; - case NODE_MODE::BRANCH_LT: + case NODE_MODE_ORT::BRANCH_LT: TREE_FIND_VALUE(<) break; - case NODE_MODE::BRANCH_GTE: + case NODE_MODE_ORT::BRANCH_GTE: TREE_FIND_VALUE(>=) break; - case NODE_MODE::BRANCH_GT: + case NODE_MODE_ORT::BRANCH_GT: TREE_FIND_VALUE(>) break; - case NODE_MODE::BRANCH_EQ: + case NODE_MODE_ORT::BRANCH_EQ: TREE_FIND_VALUE(==) break; - case NODE_MODE::BRANCH_NEQ: + case NODE_MODE_ORT::BRANCH_NEQ: TREE_FIND_VALUE(!=) break; - case NODE_MODE::LEAF: + case NODE_MODE_ORT::BRANCH_MEMBER: + if (has_missing_tracks_) { + while (root->is_not_leaf()) { + val = x_data[root->feature_id]; + root = (SetMembershipCheck(val, root->value_or_unique_weight) || (root->is_missing_track_true() && _isnan_(val))) + ? root->truenode_or_weight.ptr + : root + 1; + } + } else { + while (root->is_not_leaf()) { + val = x_data[root->feature_id]; + root = SetMembershipCheck(val, root->value_or_unique_weight) ? root->truenode_or_weight.ptr : root + 1; + } + } + case NODE_MODE_ORT::LEAF: break; } } else { // Different rules to compare to node thresholds. @@ -735,31 +777,36 @@ TreeEnsembleCommon::ProcessTreeNodeLeave( val = x_data[root->feature_id]; threshold = root->value_or_unique_weight; switch (root->mode()) { - case NODE_MODE::BRANCH_LEQ: + case NODE_MODE_ORT::BRANCH_LEQ: root = val <= threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr : root + 1; break; - case NODE_MODE::BRANCH_LT: + case NODE_MODE_ORT::BRANCH_LT: root = val < threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr : root + 1; break; - case NODE_MODE::BRANCH_GTE: + case NODE_MODE_ORT::BRANCH_GTE: root = val >= threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr : root + 1; break; - case NODE_MODE::BRANCH_GT: + case NODE_MODE_ORT::BRANCH_GT: root = val > threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr : root + 1; break; - case NODE_MODE::BRANCH_EQ: + case NODE_MODE_ORT::BRANCH_EQ: root = val == threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr : root + 1; break; - case NODE_MODE::BRANCH_NEQ: + case NODE_MODE_ORT::BRANCH_NEQ: root = val != threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr : root + 1; break; - case NODE_MODE::LEAF: + case NODE_MODE_ORT::BRANCH_MEMBER: + root = (SetMembershipCheck(val, root->value_or_unique_weight) || (root->is_missing_track_true() && _isnan_(val))) + ? root->truenode_or_weight.ptr + : root + 1; + break; + case NODE_MODE_ORT::LEAF: return root; } } @@ -786,67 +833,13 @@ class TreeEnsembleCommonClassifier : public TreeEnsembleCommon& base_values, - const std::vector& base_values_as_tensor, - const std::vector& nodes_falsenodeids, - const std::vector& nodes_featureids, - const std::vector& nodes_hitrates, - const std::vector& nodes_hitrates_as_tensor, - const std::vector& nodes_missing_value_tracks_true, - const std::vector& nodes_modes, - const std::vector& nodes_nodeids, - const std::vector& nodes_treeids, - const std::vector& nodes_truenodeids, - const std::vector& nodes_values, - const std::vector& nodes_values_as_tensor, - const std::string& post_transform, - const std::vector& class_ids, - const std::vector& class_nodeids, - const std::vector& class_treeids, - const std::vector& class_weights, - const std::vector& class_weights_as_tensor, - const std::vector& classlabels_strings, - const std::vector& classlabels_int64s); + const TreeEnsembleAttributesV3& attributes); }; template Status TreeEnsembleCommonClassifier::Init(const OpKernelInfo& info) { - std::vector base_values_as_tensor, nodes_hitrates_as_tensor, - nodes_values_as_tensor, class_weights_as_tensor; -#if !defined(ORT_MINIMAL_BUILD) - ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "base_values_as_tensor", base_values_as_tensor)); - ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_hitrates_as_tensor", nodes_hitrates_as_tensor)); - ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "nodes_values_as_tensor", nodes_values_as_tensor)); - ORT_THROW_IF_ERROR(GetVectorAttrsOrDefault(info, "class_weights_as_tensor", class_weights_as_tensor)); -#endif - - return Init( - 80, - 128, - 50, - info.GetAttrOrDefault("aggregate_function", "SUM"), - info.GetAttrsOrDefault("base_values"), - base_values_as_tensor, - info.GetAttrsOrDefault("nodes_falsenodeids"), - info.GetAttrsOrDefault("nodes_featureids"), - info.GetAttrsOrDefault("nodes_hitrates"), - nodes_hitrates_as_tensor, - info.GetAttrsOrDefault("nodes_missing_value_tracks_true"), - info.GetAttrsOrDefault("nodes_modes"), - info.GetAttrsOrDefault("nodes_nodeids"), - info.GetAttrsOrDefault("nodes_treeids"), - info.GetAttrsOrDefault("nodes_truenodeids"), - info.GetAttrsOrDefault("nodes_values"), - nodes_values_as_tensor, - info.GetAttrOrDefault("post_transform", "NONE"), - info.GetAttrsOrDefault("class_ids"), - info.GetAttrsOrDefault("class_nodeids"), - info.GetAttrsOrDefault("class_treeids"), - info.GetAttrsOrDefault("class_weights"), - class_weights_as_tensor, - info.GetAttrsOrDefault("classlabels_strings"), - info.GetAttrsOrDefault("classlabels_int64s")); + TreeEnsembleAttributesV3 attributes(info, true); + return Init(80, 128, 50, attributes); } template @@ -854,65 +847,20 @@ Status TreeEnsembleCommonClassifier::Init( int parallel_tree, int parallel_tree_N, int parallel_N, - const std::string& aggregate_function, - const std::vector& base_values, - const std::vector& base_values_as_tensor, - const std::vector& nodes_falsenodeids, - const std::vector& nodes_featureids, - const std::vector& nodes_hitrates, - const std::vector& nodes_hitrates_as_tensor, - const std::vector& nodes_missing_value_tracks_true, - const std::vector& nodes_modes, - const std::vector& nodes_nodeids, - const std::vector& nodes_treeids, - const std::vector& nodes_truenodeids, - const std::vector& nodes_values, - const std::vector& nodes_values_as_tensor, - const std::string& post_transform, - const std::vector& class_ids, - const std::vector& class_nodeids, - const std::vector& class_treeids, - const std::vector& class_weights, - const std::vector& class_weights_as_tensor, - const std::vector& classlabels_strings, - const std::vector& classlabels_int64s) { - auto status = TreeEnsembleCommon::Init( - parallel_tree, - parallel_tree_N, - parallel_N, - aggregate_function, - base_values, - base_values_as_tensor, - classlabels_strings.empty() ? classlabels_int64s.size() - : classlabels_strings.size(), - nodes_falsenodeids, - nodes_featureids, - nodes_hitrates, - nodes_hitrates_as_tensor, - nodes_missing_value_tracks_true, - nodes_modes, - nodes_nodeids, - nodes_treeids, - nodes_truenodeids, - nodes_values, - nodes_values_as_tensor, - post_transform, - class_ids, - class_nodeids, - class_treeids, - class_weights, - class_weights_as_tensor); + const TreeEnsembleAttributesV3& attributes) { + auto status = TreeEnsembleCommon::Init(parallel_tree, parallel_tree_N, parallel_N, attributes); ORT_RETURN_IF_ERROR(status); - classlabels_strings_ = classlabels_strings; - classlabels_int64s_ = classlabels_int64s; + classlabels_strings_ = attributes.classlabels_strings; + classlabels_int64s_ = attributes.classlabels_int64s; InlinedHashSet weights_classes; - weights_classes.reserve(class_ids.size()); + weights_classes.reserve(attributes.target_class_ids.size()); weights_are_all_positive_ = true; - for (size_t i = 0, end = class_ids.size(); i < end; ++i) { - weights_classes.insert(class_ids[i]); - if (weights_are_all_positive_ && (!class_weights.empty() ? class_weights[i] : class_weights_as_tensor[i]) < 0) + for (size_t i = 0, end = attributes.target_class_ids.size(); i < end; ++i) { + weights_classes.insert(attributes.target_class_ids[i]); + if (weights_are_all_positive_ && (!attributes.target_class_weights.empty() ? attributes.target_class_weights[i] + : attributes.target_class_weights_as_tensor[i]) < 0) weights_are_all_positive_ = false; } binary_case_ = this->n_targets_or_classes_ == 2 && weights_classes.size() == 1; @@ -957,6 +905,43 @@ Status TreeEnsembleCommonClassifier::compu return Status::OK(); } +template +class TreeEnsembleCommonV5 : public TreeEnsembleCommon { + public: + virtual Status Init(const OpKernelInfo& info); + + Status Init(int parallel_tree, + int parallel_tree_N, + int parallel_N, + const TreeEnsembleAttributesV5& attributes); +}; + +template +Status TreeEnsembleCommonV5::Init(const OpKernelInfo& info) { + TreeEnsembleAttributesV5 attributes(info); + return Init(80, 128, 50, attributes); +} + +template +Status TreeEnsembleCommonV5::Init( + int parallel_tree, + int parallel_tree_N, + int parallel_N, + const TreeEnsembleAttributesV5& attributes) { + TreeEnsembleAttributesV3 attributes_v3; + attributes.convert_to_v3(attributes_v3); + + attributes_v3.base_values.clear(); + attributes_v3.base_values_as_tensor.clear(); + attributes_v3.nodes_hitrates.clear(); + attributes_v3.nodes_values.clear(); + attributes_v3.target_class_weights.clear(); + + auto status = TreeEnsembleCommon::Init(parallel_tree, parallel_tree_N, parallel_N, attributes_v3); + ORT_RETURN_IF_ERROR(status); + return Status::OK(); +} + } // namespace detail } // namespace ml } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.cc b/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.cc index e2981da3a6f25..399dfd56b93c6 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.cc +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.cc @@ -5,63 +5,53 @@ #include "core/providers/cpu/ml/tree_ensemble_helper.h" #include "core/common/common.h" +#include "core/common/safeint.h" #include "onnx/defs/tensor_proto_util.h" +#include "core/framework/tensorprotoutils.h" using namespace ::onnxruntime::common; using namespace std; namespace onnxruntime { namespace ml { -Status GetNumberOfElementsAttrsOrDefault(const OpKernelInfo& info, const std::string& name, - ONNX_NAMESPACE::TensorProto_DataType proto_type, - size_t& n_elements, ONNX_NAMESPACE::TensorProto& proto) { - auto status = info.GetAttr(name, &proto); - if (!status.IsOK()) { - // Attribute is missing, n_elements is set to 0. - n_elements = 0; - return Status::OK(); - } - auto n_dims = proto.dims_size(); - if (n_dims == 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Attribute:'", name, "' is specified but is empty."); - } - ORT_ENFORCE(n_dims == 1, "Attribute '", name, "' must be a vector."); - ORT_ENFORCE(proto.data_type() == proto_type, - "Unexpected type (", proto.data_type(), "(for attribute '", name, "'."); - - n_elements = onnxruntime::narrow(proto.dims()[0]); - ORT_ENFORCE(n_elements > 0, "Attribute '", name, "' has one dimension but is empty."); - return Status::OK(); -} +template +Status GetAnyVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector& data) { + ONNX_NAMESPACE::TensorProto proto; + auto result = info.GetAttr(name, &proto); -template -Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, - ONNX_NAMESPACE::TensorProto_DataType proto_type, std::vector& data) { - if (proto_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_DOUBLE) { - ORT_ENFORCE((std::is_same::value)); - } else if (proto_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT) { - ORT_ENFORCE((std::is_same::value)); - } else { - ORT_NOT_IMPLEMENTED("GetVectorAttrsOrDefault not implemented for type ", proto_type); + SafeInt n_elements(1); + for (auto dim : proto.dims()) { + n_elements *= dim; } - ONNX_NAMESPACE::TensorProto proto; - size_t n_elements; - data.clear(); - ORT_THROW_IF_ERROR(GetNumberOfElementsAttrsOrDefault(info, name, proto_type, n_elements, proto)); - if (n_elements == 0) { + if (proto.dims().empty()) { return Status::OK(); } - data = ONNX_NAMESPACE::ParseData(&proto); + + const SafeInt tensor_size(n_elements); + data.clear(); + data.resize(tensor_size); + + result = utils::UnpackTensor(proto, std::filesystem::path(), data.data(), tensor_size); + ORT_ENFORCE(result.IsOK(), "TreeEnsemble could not unpack tensor attribute ", name); + return Status::OK(); } Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector& data) { - return GetVectorAttrsOrDefault(info, name, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_DOUBLE, data); + return GetAnyVectorAttrsOrDefault(info, name, data); } Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector& data) { - return GetVectorAttrsOrDefault(info, name, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT, data); + return GetAnyVectorAttrsOrDefault(info, name, data); +} + +Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector& data) { + return GetAnyVectorAttrsOrDefault(info, name, data); +} + +Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector& data) { + return GetAnyVectorAttrsOrDefault(info, name, data); } } // namespace ml diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.h index 33172c343a88e..ba23f1ad28ec1 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_helper.h @@ -13,6 +13,8 @@ namespace ml { Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector& data); Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector& data); +Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector& data); +Status GetVectorAttrsOrDefault(const OpKernelInfo& info, const std::string& name, std::vector& data); } // namespace ml } // namespace onnxruntime diff --git a/onnxruntime/python/tools/transformers/float16.py b/onnxruntime/python/tools/transformers/float16.py index 2398bb9d6031b..74adc951c4aa3 100644 --- a/onnxruntime/python/tools/transformers/float16.py +++ b/onnxruntime/python/tools/transformers/float16.py @@ -132,6 +132,7 @@ def make_value_info_from_tensor(tensor): "Scaler", "TreeEnsembleClassifier", "TreeEnsembleRegressor", + "TreeEnsemble", "ZipMap", "NonMaxSuppression", "TopK", diff --git a/onnxruntime/test/providers/cpu/ml/tree_ensembler_test.cc b/onnxruntime/test/providers/cpu/ml/tree_ensembler_test.cc new file mode 100644 index 0000000000000..49bb0ae65d1c9 --- /dev/null +++ b/onnxruntime/test/providers/cpu/ml/tree_ensembler_test.cc @@ -0,0 +1,294 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +static ONNX_NAMESPACE::TensorProto make_tensor(std::vector array, std::string name) { + ONNX_NAMESPACE::TensorProto array_as_tensor; + array_as_tensor.set_name(name); + array_as_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_DOUBLE); + array_as_tensor.add_dims(array.size()); + for (auto v : array) { + array_as_tensor.add_double_data(v); + } + + return array_as_tensor; +} + +static ONNX_NAMESPACE::TensorProto make_tensor(std::vector array, std::string name) { + ONNX_NAMESPACE::TensorProto array_as_tensor; + array_as_tensor.set_name(name); + array_as_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + array_as_tensor.add_dims(array.size()); + for (auto v : array) { + array_as_tensor.add_float_data(v); + } + + return array_as_tensor; +} + +static ONNX_NAMESPACE::TensorProto make_tensor(std::vector array, std::string name) { + ONNX_NAMESPACE::TensorProto array_as_tensor; + array_as_tensor.set_name(name); + array_as_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_UINT8); + array_as_tensor.add_dims(array.size()); + for (const auto v : array) { + array_as_tensor.add_int32_data(v); + } + + return array_as_tensor; +} + +template +void _multiply_update_array(std::vector& data, int n, T inc = 0) { + std::vector copy = data; + data.resize(copy.size() * n); + T cst = 0; + for (int i = 0; i < n; ++i) { + for (size_t j = 0; j < copy.size(); ++j) { + data[j + i * copy.size()] = copy[j] + cst; + } + cst += inc; + } +} + +template +void _multiply_update_childnode(std::vector& childnodes, std::vector& childleafs, std::vector& otherchildleafs, int n) { + int64_t leafs_cnt = 0; + int64_t nodes_cnt = childnodes.size(); + for (auto& childleaf : childleafs) { + if (childleaf) { + leafs_cnt++; + } + } + for (auto& childleaf : otherchildleafs) { + if (childleaf) { + leafs_cnt++; + } + } + + std::vector copy = childnodes; + childnodes.resize(copy.size() * n); + T leafs_cst = 0; + T nodes_cst = 0; + for (int i = 0; i < n; ++i) { + for (size_t j = 0; j < copy.size(); ++j) { + T curr_inc = childleafs[j] ? leafs_cst : nodes_cst; + childnodes[j + i * copy.size()] = copy[j] + curr_inc; + } + + leafs_cst += leafs_cnt; + nodes_cst += nodes_cnt; + } +} + +template +void _multiply_arrays_values(std::vector& data, int64_t val) { + for (auto& curr : data) { + curr *= val; + } +} + +template +void GenTreeAndRunTest(const std::vector& X, const std::vector& Y, const int64_t& aggregate_function, int n_trees = 1) { + OpTester test("TreeEnsemble", 5, onnxruntime::kMLDomain); + int64_t n_targets = 2; + + int64_t post_transform = 0; + std::vector tree_roots = {0}; + std::vector nodes_featureids = {0, 0, 0}; + std::vector nodes_modes = {0, 0, 0}; + std::vector nodes_splits = {3.14f, 1.2f, 4.2f}; + std::vector nodes_truenodeids = {1, 0, 1}; + std::vector nodes_trueleafs = {0, 1, 1}; + std::vector nodes_falsenodeids = {2, 2, 3}; + std::vector nodes_falseleafs = {0, 1, 1}; + + std::vector leaf_targetids = {0, 1, 0, 1}; + std::vector leaf_weights = {5.23f, 12.12f, -12.23f, 7.21f}; + + if (n_trees > 1) { + // Multiplies the number of trees to test the parallelization by trees. + _multiply_update_array(tree_roots, n_trees, (int64_t)nodes_truenodeids.size()); + _multiply_update_array(nodes_featureids, n_trees); + _multiply_update_childnode(nodes_truenodeids, nodes_trueleafs, nodes_falseleafs, n_trees); + _multiply_update_childnode(nodes_falsenodeids, nodes_falseleafs, nodes_trueleafs, n_trees); + _multiply_update_array(nodes_trueleafs, n_trees); + _multiply_update_array(nodes_falseleafs, n_trees); + _multiply_update_array(leaf_targetids, n_trees); + _multiply_update_array(nodes_modes, n_trees); + _multiply_update_array(nodes_splits, n_trees); + _multiply_update_array(leaf_weights, n_trees); + } + + auto nodes_modes_as_tensor = make_tensor(nodes_modes, "nodes_modes"); + auto nodes_splits_as_tensor = make_tensor(nodes_splits, "nodes_splits"); + auto leaf_weights_as_tensor = make_tensor(leaf_weights, "leaf_weight"); + + // add attributes + test.AddAttribute("n_targets", n_targets); + test.AddAttribute("aggregate_function", aggregate_function); + test.AddAttribute("post_transform", post_transform); + test.AddAttribute("tree_roots", tree_roots); + test.AddAttribute("nodes_modes", nodes_modes_as_tensor); + test.AddAttribute("nodes_featureids", nodes_featureids); + test.AddAttribute("nodes_splits", nodes_splits_as_tensor); + test.AddAttribute("nodes_truenodeids", nodes_truenodeids); + test.AddAttribute("nodes_trueleafs", nodes_trueleafs); + test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids); + test.AddAttribute("nodes_falseleafs", nodes_falseleafs); + test.AddAttribute("leaf_targetids", leaf_targetids); + test.AddAttribute("leaf_weights", leaf_weights_as_tensor); + + // fill input data + test.AddInput("X", {3, 2}, X); + test.AddOutput("Y", {3, 2}, Y); + test.Run(); +} + +template +void GenTreeAndRunTestWithSetMembership(const std::vector& X, const std::vector& Y, const int64_t& aggregate_function, int n_trees = 1) { + OpTester test("TreeEnsemble", 5, onnxruntime::kMLDomain); + int64_t n_targets = 4; + + int64_t post_transform = 0; + std::vector tree_roots = {0}; + std::vector nodes_featureids = {0, 0, 0}; + std::vector nodes_truenodeids = {1, 0, 1}; + std::vector nodes_trueleafs = {0, 1, 1}; + std::vector nodes_falsenodeids = {2, 2, 3}; + std::vector nodes_falseleafs = {1, 0, 1}; + std::vector leaf_targetids = {0, 1, 2, 3}; + + std::vector nodes_modes = {0, 6, 6}; + std::vector nodes_splits = {11.f, 232344.f, NAN}; + std::vector membership_values = {1.2f, 3.7f, 8.f, 9.f, NAN, 12.f, 7.f, NAN}; + std::vector leaf_weights = {1.f, 10.f, 1000.f, 100.f}; + + if (n_trees > 1) { + // Multiplies the number of trees to test the parallelization by trees. + _multiply_update_array(tree_roots, n_trees, (int64_t)nodes_truenodeids.size()); + _multiply_update_array(nodes_featureids, n_trees); + _multiply_update_childnode(nodes_truenodeids, nodes_trueleafs, nodes_falseleafs, n_trees); + _multiply_update_childnode(nodes_falsenodeids, nodes_falseleafs, nodes_trueleafs, n_trees); + _multiply_update_array(nodes_trueleafs, n_trees); + _multiply_update_array(nodes_falseleafs, n_trees); + _multiply_update_array(leaf_targetids, n_trees); + _multiply_update_array(nodes_modes, n_trees); + _multiply_update_array(nodes_splits, n_trees); + _multiply_update_array(membership_values, n_trees); + _multiply_update_array(leaf_weights, n_trees); + } + + auto nodes_modes_as_tensor = make_tensor(nodes_modes, "nodes_modes"); + auto nodes_splits_as_tensor = make_tensor(nodes_splits, "nodes_splits"); + auto membership_values_as_tensor = make_tensor(membership_values, "membership_values"); + auto leaf_weights_as_tensor = make_tensor(leaf_weights, "leaf_weight"); + + // add attributes + test.AddAttribute("n_targets", n_targets); + test.AddAttribute("aggregate_function", aggregate_function); + test.AddAttribute("post_transform", post_transform); + test.AddAttribute("tree_roots", tree_roots); + test.AddAttribute("nodes_modes", nodes_modes_as_tensor); + test.AddAttribute("nodes_featureids", nodes_featureids); + test.AddAttribute("nodes_splits", nodes_splits_as_tensor); + test.AddAttribute("membership_values", membership_values_as_tensor); + test.AddAttribute("nodes_truenodeids", nodes_truenodeids); + test.AddAttribute("nodes_trueleafs", nodes_trueleafs); + test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids); + test.AddAttribute("nodes_falseleafs", nodes_falseleafs); + test.AddAttribute("leaf_targetids", leaf_targetids); + test.AddAttribute("leaf_weights", leaf_weights_as_tensor); + + // fill input data + test.AddInput("X", {6, 1}, X); + test.AddOutput("Y", {6, 4}, Y); + test.Run(); +} + +TEST(MLOpTest, TreeEnsembleFloat) { + std::vector X = {1.2f, 3.4f, -0.12f, 1.66f, 4.14f, 1.77f}; + std::vector Y = {5.23f, 0.f, 5.23f, 0.f, 0.f, 12.12f}; + GenTreeAndRunTest(X, Y, 1, 1); + + Y = {15.69f, 0.f, 15.69f, 0.f, 0.f, 36.36f}; + GenTreeAndRunTest(X, Y, 1, 3); +} + +TEST(MLOpTest, TreeEnsembleDouble) { + std::vector X = {1.2f, 3.4f, -0.12f, 1.66f, 4.14f, 1.77f}; + std::vector Y = {5.23f, 0.f, 5.23f, 0.f, 0.f, 12.12f}; + GenTreeAndRunTest(X, Y, 1, 1); + + _multiply_arrays_values(Y, 3); + GenTreeAndRunTest(X, Y, 1, 3); +} + +TEST(MLOpTest, TreeEnsembleSetMembership) { + std::vector X = {1.2f, 3.4f, -0.12f, NAN, 12.0f, 7.0f}; + std::vector Y = { + 1.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 100.f, + 0.f, 0.f, 0.f, 100.f, + 0.f, 0.f, 1000.f, 0.f, + 0.f, 0.f, 1000.f, 0.f, + 0.f, 10.f, 0.f, 0.f}; + GenTreeAndRunTestWithSetMembership(X, Y, 1, 1); + + _multiply_arrays_values(Y, 5); + GenTreeAndRunTestWithSetMembership(X, Y, 1, 5); +} + +TEST(MLOpTest, TreeEnsembleLeafOnly) { + OpTester test("TreeEnsemble", 5, onnxruntime::kMLDomain); + int64_t n_targets = 1; + + int64_t aggregate_function = 1; + int64_t post_transform = 0; + std::vector tree_roots = {0}; + std::vector nodes_modes = {0}; + std::vector nodes_featureids = {0}; + std::vector nodes_splits = {0.f}; + std::vector nodes_truenodeids = {0}; + std::vector nodes_trueleafs = {1}; + std::vector nodes_falsenodeids = {0}; + std::vector nodes_falseleafs = {1}; + + std::vector leaf_targetids = {0}; + std::vector leaf_weights = {6.23f}; + + auto nodes_modes_as_tensor = make_tensor(nodes_modes, "nodes_modes"); + auto nodes_splits_as_tensor = make_tensor(nodes_splits, "nodes_splits"); + auto leaf_weights_as_tensor = make_tensor(leaf_weights, "leaf_weight"); + + // add attributes + test.AddAttribute("n_targets", n_targets); + test.AddAttribute("aggregate_function", aggregate_function); + test.AddAttribute("post_transform", post_transform); + test.AddAttribute("tree_roots", tree_roots); + test.AddAttribute("nodes_modes", nodes_modes_as_tensor); + test.AddAttribute("nodes_featureids", nodes_featureids); + test.AddAttribute("nodes_splits", nodes_splits_as_tensor); + test.AddAttribute("nodes_truenodeids", nodes_truenodeids); + test.AddAttribute("nodes_trueleafs", nodes_trueleafs); + test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids); + test.AddAttribute("nodes_falseleafs", nodes_falseleafs); + test.AddAttribute("leaf_targetids", leaf_targetids); + test.AddAttribute("leaf_weights", leaf_weights_as_tensor); + + // fill input data + std::vector X = {1.f, 4.f}; + std::vector Y = {6.23f, 6.23f}; + + test.AddInput("X", {2, 1}, X); + test.AddOutput("Y", {2, 1}, Y); + test.Run(); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc b/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc index 33c23b53fb5aa..eaf8fea03eaa0 100644 --- a/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc +++ b/onnxruntime/test/providers/cpu/ml/treeregressor_test.cc @@ -679,6 +679,90 @@ TEST(MLOpTest, TreeRegressorSingleTargetSum_as_tensor_precision) { GenTreeAndRunTest1_as_tensor_precision(3); } +TEST(MLOpTest, TreeRegressorCategoricals) { + OpTester test("TreeEnsembleRegressor", 3, onnxruntime::kMLDomain); + + // tree + int64_t n_targets = 1; + std::vector nodes_featureids = {0, 0, 0, 0, 1, 0, 0}; + std::vector nodes_modes = {"BRANCH_EQ", "BRANCH_EQ", "BRANCH_EQ", "LEAF", "BRANCH_LEQ", "LEAF", "LEAF"}; + std::vector nodes_values = {1, 3, 4, 0, 5.5, 0, 0}; + + std::vector nodes_treeids = {0, 0, 0, 0, 0, 0, 0}; + std::vector nodes_nodeids = {0, 1, 2, 3, 4, 5, 6}; + std::vector nodes_falsenodeids = {1, 2, 3, 0, 5, 0, 0}; + std::vector nodes_truenodeids = {4, 4, 4, 0, 6, 0, 0}; + + std::string post_transform = "NONE"; + std::vector target_ids = {0, 0, 0}; + std::vector target_nodeids = {3, 5, 6}; + std::vector target_treeids = {0, 0, 0}; + std::vector target_weights = {-4.699999809265137, 17.700000762939453, 11.100000381469727}; + + // add attributes + test.AddAttribute("nodes_truenodeids", nodes_truenodeids); + test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids); + test.AddAttribute("nodes_treeids", nodes_treeids); + test.AddAttribute("nodes_nodeids", nodes_nodeids); + test.AddAttribute("nodes_featureids", nodes_featureids); + test.AddAttribute("nodes_values", nodes_values); + test.AddAttribute("nodes_modes", nodes_modes); + test.AddAttribute("target_treeids", target_treeids); + test.AddAttribute("target_nodeids", target_nodeids); + test.AddAttribute("target_ids", target_ids); + test.AddAttribute("target_weights", target_weights); + test.AddAttribute("n_targets", n_targets); + + // fill input data + std::vector X = {3.0f, 6.6f, 1.0f, 5.0f, 5.0f, 5.5f}; + std::vector Y = {17.700000762939453, 11.100000381469727, -4.699999809265137}; + test.AddInput("X", {3, 2}, X); + test.AddOutput("Y", {3, 1}, Y); + test.Run(); +} + +TEST(MLOpTest, TreeRegressorCategoricalsFolding) { + OpTester test("TreeEnsembleRegressor", 3, onnxruntime::kMLDomain); + + // tree + int64_t n_targets = 1; + std::vector nodes_featureids = {0, 0, 1, 1, 0, 0, 0}; + std::vector nodes_modes = {"BRANCH_EQ", "BRANCH_EQ", "BRANCH_EQ", "BRANCH_EQ", "LEAF", "LEAF", "LEAF"}; + std::vector nodes_values = {1, 3, 2, 3, 0, 0, 0}; + + std::vector nodes_treeids = {0, 0, 0, 0, 0, 0, 0}; + std::vector nodes_nodeids = {0, 1, 2, 3, 4, 5, 6}; + std::vector nodes_falsenodeids = {1, 2, 3, 4, 0, 0, 0}; + std::vector nodes_truenodeids = {5, 5, 6, 6, 0, 0, 0}; + + std::string post_transform = "NONE"; + std::vector target_ids = {0, 0, 0}; + std::vector target_nodeids = {4, 5, 6}; + std::vector target_treeids = {0, 0, 0}; + std::vector target_weights = {17.700000762939453, 11.100000381469727, -4.699999809265137}; + + // add attributes + test.AddAttribute("nodes_truenodeids", nodes_truenodeids); + test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids); + test.AddAttribute("nodes_treeids", nodes_treeids); + test.AddAttribute("nodes_nodeids", nodes_nodeids); + test.AddAttribute("nodes_featureids", nodes_featureids); + test.AddAttribute("nodes_values", nodes_values); + test.AddAttribute("nodes_modes", nodes_modes); + test.AddAttribute("target_treeids", target_treeids); + test.AddAttribute("target_nodeids", target_nodeids); + test.AddAttribute("target_ids", target_ids); + test.AddAttribute("target_weights", target_weights); + test.AddAttribute("n_targets", n_targets); + + // fill input data + std::vector X = {1.0f, 2.0f, 3.0f, 1.0f, 2.0f, 3.0f, 2.0f, 1.0f}; + std::vector Y = {11.100000381469727, 11.100000381469727, -4.699999809265137, 17.700000762939453}; + test.AddInput("X", {4, 2}, X); + test.AddOutput("Y", {4, 1}, Y); + test.Run(); +} + TEST(MLOpTest, TreeRegressorTrueNodeBeforeNode) { OpTester test("TreeEnsembleRegressor", 3, onnxruntime::kMLDomain);