Skip to content

Commit

Permalink
Label encoder fusion (#19761)
Browse files Browse the repository at this point in the history
### Description
Created a new `LabelEncoderFusion` pass. This is useful in models that
result from automatic data-science conversion tools: sometimes the
produced model contains consecutive `LabelEncoder` nodes.
To merge two `LabelEncoder` nodes, the optimizer propagates the outputs of the
first encoder through the second one.


### Motivation and Context
This enhances the capabilities of the `onnxruntime::optimizer` by fusing
consecutive `LabelEncoder` nodes.


### Fusion examples
```
Applying fusion
node1: (a,C) (b,B) (c,A) -> Default: _Unused
node2: (A,1) (B,2) (C,3) -> Default: -1
fused: (a,3) (b,2) (c,1) -> Default: -1
Applying fusion
node1: (a,C) (b,B) (c,A) -> Default: D
node2: (A,a) (B,b) (C,c) (D,d) -> Default: default
fused: (a,c) (b,b) (c,a) -> Default: d
Applying fusion
node1: (a,0) (b,1) (c,2) -> Default: -1
node2: (2,a) (1,b) (0,c) -> Default: default
fused: (a,c) (b,b) (c,a) -> Default: default
Applying fusion
node1: (a,3) (b,2) (c,1) -> Default: -1
node2: (1,a) (2,b) (3,c) -> Default: d
fused: (a,c) (b,b) (c,a) -> Default: d
```

---------

Co-authored-by: Justin Chu <[email protected]>
  • Loading branch information
neNasko1 and justinchuby authored Apr 1, 2024
1 parent 523ef04 commit 9d06e1b
Show file tree
Hide file tree
Showing 6 changed files with 425 additions and 0 deletions.
2 changes: 2 additions & 0 deletions onnxruntime/core/optimizer/graph_transformer_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include "core/optimizer/identical_children_consolidation.h"
#include "core/optimizer/identity_elimination.h"
#include "core/optimizer/layer_norm_fusion.h"
#include "core/optimizer/label_encoder_fusion.h"
#include "core/optimizer/matmul_activation_fusion.h"
#include "core/optimizer/matmul_add_fusion.h"
#include "core/optimizer/matmul_integer_to_float.h"
Expand Down Expand Up @@ -133,6 +134,7 @@ InlinedVector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(
rules.push_back(std::make_unique<MatmulBNFusion>());
rules.push_back(std::make_unique<ClipQuantFusion>());
rules.push_back(std::make_unique<ReluQuantFusion>());
rules.push_back(std::make_unique<LabelEncoderFusion>());
break;

case TransformerLevel::Level2:
Expand Down
162 changes: 162 additions & 0 deletions onnxruntime/core/optimizer/label_encoder_fusion.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <unordered_map>
#include <vector>
#include <string>

#include "core/optimizer/label_encoder_fusion.h"
#include "core/framework/op_node_proto_helper.h"
#include "core/graph/graph_utils.h"
#include "core/optimizer/utils.h"

namespace onnxruntime {

#define KEYS_ATTR_NAME(T) ("keys_" + GetTypename<T>() + "s")
#define VALUES_ATTR_NAME(T) ("values_" + GetTypename<T>() + "s")
#define DEFAULT_VALUE_ATTR_NAME(T) ("default_" + GetTypename<T>())

// May be needed somewhere else
// Think about moving into utils
// Dependent-false helper: makes the static_assert below fire only when an
// unsupported type is actually instantiated.
// May be needed somewhere else; think about moving into utils.
template <typename>
[[maybe_unused]] constexpr bool false_for_T = false;

/// Returns the type-name suffix used in LabelEncoder attribute names
/// ("string", "int64" or "float"). Any other type fails to compile.
template <typename T>
std::string GetTypename() {
  if constexpr (std::is_same_v<T, std::string>) {
    return "string";
  } else if constexpr (std::is_same_v<T, int64_t>) {
    return "int64";
  } else if constexpr (std::is_same_v<T, float>) {
    return "float";
  } else {
    static_assert(false_for_T<T>, "Unsupported type");
  }
}

template <typename T1, typename T2, typename T3>
bool LabelEncoderFusion::IsValidForFusion(const Node& node, const Node& next_node) const {
return (node.GetAttributes().find(KEYS_ATTR_NAME(T1)) != node.GetAttributes().end() &&
node.GetAttributes().find(VALUES_ATTR_NAME(T2)) != node.GetAttributes().end() &&
next_node.GetAttributes().find(KEYS_ATTR_NAME(T2)) != next_node.GetAttributes().end() &&
next_node.GetAttributes().find(VALUES_ATTR_NAME(T3)) != next_node.GetAttributes().end());
}

/**
Transform that fuses two consecutive LabelEncoder nodes
into one LabelEncoder node.
*/
bool LabelEncoderFusion::SatisfyCondition(const Graph& graph, const Node& node,
                                          const logging::Logger& /*logger*/) const {
  // The first encoder must be a LabelEncoder (opset 2 or 4), feed exactly one
  // consumer, and must not produce a graph output (it is rewritten in place).
  if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "LabelEncoder", {2, 4}, "ai.onnx.ml") ||
      node.GetOutputEdgesCount() != 1 ||
      graph.NodeProducesGraphOutput(node)) {
    return false;
  }

  // The second encoder must be an opset-4 LabelEncoder on the same execution
  // provider, so the two nodes do not span execution providers.
  const auto& successor = *node.OutputNodesBegin();
  if (!graph_utils::IsSupportedOptypeVersionAndDomain(successor, "LabelEncoder", {4}, "ai.onnx.ml") ||
      successor.GetExecutionProviderType() != node.GetExecutionProviderType()) {
    return false;
  }

  // Finally, at least one supported <T1, T2, T3> type combination must match.
  return IsValidForFusion<std::string, std::string, std::string>(node, successor) ||
         IsValidForFusion<std::string, std::string, int64_t>(node, successor) ||
         IsValidForFusion<std::string, int64_t, std::string>(node, successor) ||
         IsValidForFusion<std::string, int64_t, int64_t>(node, successor) ||
         IsValidForFusion<int64_t, std::string, std::string>(node, successor) ||
         IsValidForFusion<int64_t, std::string, int64_t>(node, successor) ||
         IsValidForFusion<int64_t, int64_t, std::string>(node, successor) ||
         IsValidForFusion<int64_t, int64_t, int64_t>(node, successor);
}

/**
Since we need to be polymorphic on the datatype
we will dispatch to this method from the main Apply.

Merges `node` (T1 -> T2) with `next_node` (T2 -> T3) into a single
LabelEncoder (T1 -> T3) by propagating each value of the first encoder
through the mapping of the second one. The keys of `node` stay untouched;
only its values/default attributes are rewritten, after which `next_node`
is removed from the graph.
*/
template <typename T1, typename T2, typename T3>
Status LabelEncoderFusion::ApplyHelper(
    Graph& graph,
    Node& node,
    Node& next_node,
    RewriteRuleEffect& rule_effect) const {
  ProtoHelperNodeContext node_helper_ctx(node);
  OpNodeProtoHelper<ProtoHelperNodeContext> node_helper(&node_helper_ctx);

  ProtoHelperNodeContext next_node_helper_ctx(next_node);
  OpNodeProtoHelper<ProtoHelperNodeContext> next_node_helper(&next_node_helper_ctx);

  const std::vector<T1> node_keys =
      node_helper.GetAttrsOrDefault<T1>(KEYS_ATTR_NAME(T1));
  const std::vector<T2> node_values =
      node_helper.GetAttrsOrDefault<T2>(VALUES_ATTR_NAME(T2));
  const T2 node_default =
      node_helper.GetAttr<T2>(DEFAULT_VALUE_ATTR_NAME(T2));

  const std::vector<T2> next_node_keys =
      next_node_helper.GetAttrsOrDefault<T2>(KEYS_ATTR_NAME(T2));
  const std::vector<T3> next_node_values =
      next_node_helper.GetAttrsOrDefault<T3>(VALUES_ATTR_NAME(T3));
  const T3 next_node_default =
      next_node_helper.GetAttr<T3>(DEFAULT_VALUE_ATTR_NAME(T3));

  // Build the key -> value mapping of the second label encoder.
  // NOTE(review): a duplicated key resolves to its last occurrence here —
  // confirm this matches the runtime kernel's behavior.
  std::unordered_map<T2, T3> mapping;
  mapping.reserve(next_node_keys.size());
  for (size_t i = 0; i < next_node_keys.size(); i++) {
    mapping[next_node_keys[i]] = next_node_values[i];
  }

  // Single-lookup map access. Key and default are taken by reference to
  // avoid copying std::string arguments on every call.
  const auto get_or_default = [](const auto& mp, const auto& key, const auto& def) {
    const auto it = mp.find(key);
    return (it == mp.end()) ? def : it->second;
  };

  // Perform value propagation through the second label encoder.
  const auto new_node_default = get_or_default(mapping, node_default, next_node_default);

  std::vector<T3> new_node_values;
  new_node_values.reserve(node_values.size());
  for (const T2& node_value : node_values) {
    new_node_values.push_back(get_or_default(mapping, node_value, next_node_default));
  }

  // Remove old attributes: the keys attribute is still correct, we just
  // reroute the values. Clearing before adding matters when T2 == T3,
  // because the attribute names collide in that case.
  node.ClearAttribute(VALUES_ATTR_NAME(T2));
  node.ClearAttribute(DEFAULT_VALUE_ATTR_NAME(T2));

  node.AddAttribute(VALUES_ATTR_NAME(T3), new_node_values);
  node.AddAttribute(DEFAULT_VALUE_ATTR_NAME(T3), new_node_default);

  // Moves next_node's output (and outgoing edges) onto node and removes
  // next_node from the graph.
  graph_utils::FinalizeNodeFusion(graph, node, next_node);

  rule_effect = RewriteRuleEffect::kModifiedRestOfGraph;

  return Status::OK();
}

// Dispatch: try each supported <T1, T2, T3> combination in turn and fuse
// using the first one whose attributes are present on both nodes.
Status LabelEncoderFusion::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect,
                                 const logging::Logger& /*logger*/) const {
  auto& next_node = *graph.GetNode(node.OutputNodesBegin()->Index());

  if (IsValidForFusion<std::string, std::string, std::string>(node, next_node)) {
    return ApplyHelper<std::string, std::string, std::string>(graph, node, next_node, rule_effect);
  }
  if (IsValidForFusion<std::string, std::string, int64_t>(node, next_node)) {
    return ApplyHelper<std::string, std::string, int64_t>(graph, node, next_node, rule_effect);
  }
  if (IsValidForFusion<std::string, int64_t, std::string>(node, next_node)) {
    return ApplyHelper<std::string, int64_t, std::string>(graph, node, next_node, rule_effect);
  }
  if (IsValidForFusion<std::string, int64_t, int64_t>(node, next_node)) {
    return ApplyHelper<std::string, int64_t, int64_t>(graph, node, next_node, rule_effect);
  }
  if (IsValidForFusion<int64_t, std::string, std::string>(node, next_node)) {
    return ApplyHelper<int64_t, std::string, std::string>(graph, node, next_node, rule_effect);
  }
  if (IsValidForFusion<int64_t, std::string, int64_t>(node, next_node)) {
    return ApplyHelper<int64_t, std::string, int64_t>(graph, node, next_node, rule_effect);
  }
  if (IsValidForFusion<int64_t, int64_t, std::string>(node, next_node)) {
    return ApplyHelper<int64_t, int64_t, std::string>(graph, node, next_node, rule_effect);
  }
  if (IsValidForFusion<int64_t, int64_t, int64_t>(node, next_node)) {
    return ApplyHelper<int64_t, int64_t, int64_t>(graph, node, next_node, rule_effect);
  }

  // SatisfyCondition guarantees one of the combinations above matches;
  // returning OK without a change is the safe fallback.
  return Status::OK();
}

} // namespace onnxruntime
35 changes: 35 additions & 0 deletions onnxruntime/core/optimizer/label_encoder_fusion.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/optimizer/rewrite_rule.h"

namespace onnxruntime {
/**
@Class LabelEncoderFusion

Rewrite rule that fuses two LabelEncoder -> LabelEncoder nodes to a single
LabelEncoder node.
*/
class LabelEncoderFusion : public RewriteRule {
 public:
  LabelEncoderFusion() noexcept : RewriteRule("LabelEncoderFusion") {}

  // The rewrite rule is only attempted on nodes with these op types.
  std::vector<std::string> TargetOpTypes() const noexcept override {
    return {"LabelEncoder"};
  }

 private:
  // Returns true when `node` and its single consumer are LabelEncoders that
  // can be merged (see the .cc file for the exact conditions).
  bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;

  // Performs the fusion; dispatches to ApplyHelper based on attribute types.
  Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;

  // T1/T2/T3 are: the key type of the first encoder, its value type (which is
  // also the key type of the second encoder), and the value type of the
  // second encoder, respectively.
  template <typename T1, typename T2, typename T3>
  Status ApplyHelper(Graph& graph, Node& node, Node& next_node, RewriteRuleEffect& rule_effect) const;

  // Checks that the attributes implied by <T1, T2, T3> exist on both nodes.
  template <typename T1, typename T2, typename T3>
  bool IsValidForFusion(const Node& node, const Node& next) const;
};

} // namespace onnxruntime
63 changes: 63 additions & 0 deletions onnxruntime/test/optimizer/graph_transform_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
#include "core/optimizer/slice_elimination.h"
#include "core/optimizer/unsqueeze_elimination.h"
#include "core/optimizer/utils.h"
#include "core/optimizer/label_encoder_fusion.h"
#include "core/platform/env.h"
#include "core/session/inference_session.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
Expand Down Expand Up @@ -1901,6 +1902,68 @@ TEST_F(GraphTransformationTests, DivMulFusion) {
ASSERT_TRUE(op_to_count["Mul"] == 2);
}

// Verifies that LabelEncoderFusion reduces the LabelEncoder count of the
// test model (11 -> 7) without changing the model's outputs.
TEST_F(GraphTransformationTests, LabelEncoderFusion) {
  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/label_encoder.onnx";

  // Feed "A": one single-character string per lowercase letter, ["a".."z"].
  NameMLValMap feeds;

  constexpr size_t ALPH = 26;
  OrtValue mlvalue_a;
  std::vector<int64_t> dims_a = {ALPH};
  std::vector<std::string> values_a = {};
  for (char letter = 'a'; letter <= 'z'; letter++) {
    values_a.emplace_back(1, letter);
  }
  CreateMLValue<std::string>(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], dims_a,
                             values_a, &mlvalue_a);
  feeds.insert(std::make_pair("A", mlvalue_a));

  // Set to false when the session cannot be created (e.g. the ai.onnx.ml
  // operators are not compiled in); the comparison is then skipped.
  bool is_implemented = true;

  // Runs the model at the given optimization level, asserts the expected
  // LabelEncoder count, and collects the model outputs into `fetches`.
  auto run_model_test = [&](TransformerLevel level, std::vector<OrtValue>& fetches,
                            const int requiredLabelEncoderCount) {
    SessionOptions session_options;
    session_options.graph_optimization_level = level;
    session_options.session_logid = "OptimizerTests";
    InferenceSessionWrapper session{session_options, GetEnvironment()};

    // If we did not initialize the session correctly, the operator is missing.
    if (!session.Load(model_uri).IsOK() || !session.Initialize().IsOK()) {
      is_implemented = false;
      return;
    }

    // Count if the number of LabelEncoders is as expected.
    std::map<std::string, int> op_to_count = CountOpsInGraph(session.GetGraph());
    ASSERT_EQ(op_to_count["ai.onnx.ml.LabelEncoder"], requiredLabelEncoderCount);

    std::vector<std::string> output_names = {};
    for (const auto& output : session.GetGraph().GetOutputs()) {
      output_names.push_back(output->Name());
    }

    RunOptions run_options;
    ASSERT_STATUS_OK(session.Run(run_options, feeds, output_names, &fetches));
  };

  // Run the model with and w/o optimizations and compare the results.
  std::vector<OrtValue> unoptimized_fetches;
  run_model_test(TransformerLevel::Default, unoptimized_fetches, 11);

  std::vector<OrtValue> optimized_fetches;
  run_model_test(TransformerLevel::MaxLevel, optimized_fetches, 7);

  // If there was a problem loading the model, do not compare the 2 results.
  if (!is_implemented) {
    GTEST_SKIP();
  }

  // The fusion must be exact: zero tolerance when comparing the outputs.
  auto ret = CompareOrtValue(optimized_fetches[0], unoptimized_fetches[0], 0.0, 0.0, false);
  EXPECT_EQ(ret.first, COMPARE_RESULT::SUCCESS) << ret.second;
}

TEST_F(GraphTransformationTests, NotWhereFusion) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/not_where.onnx";
std::shared_ptr<Model> model;
Expand Down
Binary file not shown.
Loading

0 comments on commit 9d06e1b

Please sign in to comment.