Skip to content

Commit

Permalink
Define recomputable op list with domain/opset (microsoft#19722)
Browse files Browse the repository at this point in the history
### Define recomputable op list with domain/opset

Originally, we just check the OpType and decide whether it is
recomputable.

In this PR, few improvements are made:
1. [Op type search] Domain + OpType are used to check whether the op is
supported to recompute.
2. [Opset search] Then, node.SinceVersion() will be searched in the
supported opsets.
3. During subgraph detection, If the node in that this opset is
supported, get the ignorable input indices, which means we don't
consider in the bottom-up search. This would save time for the subgraph
detection.


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
pengwa authored and Zhenze Wang committed Mar 7, 2024
1 parent ed01070 commit 91878b2
Show file tree
Hide file tree
Showing 6 changed files with 382 additions and 91 deletions.
9 changes: 8 additions & 1 deletion onnxruntime/core/common/string_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,14 @@ inline std::string TrimString(std::string s) {
}

/**
* So use this simple hash to generate unique int by given string input.
* @brief A consistent way to construct the full qualified op name.
*/
inline std::string GetFullQualifiedOpName(const std::string& op_type, const std::string& domain) {
return MakeString(domain, "::", op_type);
}

/**
* Use this simple hash to generate unique int by given string input.
*/
inline uint32_t GetHashFromString(const std::string& str_value) {
uint32_t hash = 0;
Expand Down
25 changes: 13 additions & 12 deletions onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#ifdef ENABLE_TRAINING

#include <onnx/defs/attr_proto_util.h>
#include "core/common/string_utils.h"
#include "core/graph/graph_utils.h"
#include "core/optimizer/initializer.h"
#include "core/optimizer/utils.h"
Expand All @@ -26,38 +27,38 @@ UpStreamGatherGraphTransformer::UpStreamGatherGraphTransformer(
// 2. Whether the outputs have the same dim changes if the Gather node moves before that operator.
// 3. Should all inputs be allowed when tracking back further (bottom-up);
// if not, add the input index restriction as MatMul did.
{GetFullQualifiedOpName("Add", kOnnxDomain),
{utils::GetFullQualifiedOpName("Add", kOnnxDomain),
OpPassThroughConfig<UpStreamGatherOperatorActorBase>(std::make_shared<SimplePointwiseGatherActor<true>>(),
opset_14_13_7_6_1)},
{GetFullQualifiedOpName("BiasGelu", kMSDomain),
{utils::GetFullQualifiedOpName("BiasGelu", kMSDomain),
OpPassThroughConfig<UpStreamGatherOperatorActorBase>(std::make_shared<SimplePointwiseGatherActor<true>>(), opset_1)},

{GetFullQualifiedOpName("Cast", kOnnxDomain),
{utils::GetFullQualifiedOpName("Cast", kOnnxDomain),
OpPassThroughConfig<UpStreamGatherOperatorActorBase>(std::make_shared<SimplePointwiseGatherActor<true>>(),
opset_19_13_9_6_1)},
{GetFullQualifiedOpName("Div", kOnnxDomain),
{utils::GetFullQualifiedOpName("Div", kOnnxDomain),
OpPassThroughConfig<UpStreamGatherOperatorActorBase>(std::make_shared<SimplePointwiseGatherActor<true>>(),
opset_14_13_7_6_1)},
{GetFullQualifiedOpName("Dropout", kOnnxDomain),
{utils::GetFullQualifiedOpName("Dropout", kOnnxDomain),
OpPassThroughConfig<UpStreamGatherOperatorActorBase>(std::make_shared<SimplePointwiseGatherActor<true>>(),
opset_13_12_10_7_6_1)},
{GetFullQualifiedOpName("Gelu", kMSDomain),
{utils::GetFullQualifiedOpName("Gelu", kMSDomain),
OpPassThroughConfig<UpStreamGatherOperatorActorBase>(std::make_shared<SimplePointwiseGatherActor<true>>(),
opset_1)},
{// Be noted, this is our own implementation of ONNX domain op.
GetFullQualifiedOpName("LayerNormalization", kOnnxDomain),
utils::GetFullQualifiedOpName("LayerNormalization", kOnnxDomain),
OpPassThroughConfig<UpStreamGatherOperatorActorBase>(std::make_shared<LayerNormalizationGatherActor>(),
opset_1)},
{GetFullQualifiedOpName("MatMul", kOnnxDomain),
{utils::GetFullQualifiedOpName("MatMul", kOnnxDomain),
OpPassThroughConfig<UpStreamGatherOperatorActorBase>(std::make_shared<MatMulGatherActor>(),
opset_13_9_1)},
{GetFullQualifiedOpName("Reshape", kOnnxDomain),
{utils::GetFullQualifiedOpName("Reshape", kOnnxDomain),
OpPassThroughConfig<UpStreamGatherOperatorActorBase>(std::make_shared<ReshapeGatherActor>(),
opset_19_14_13_5_1)},
{GetFullQualifiedOpName("Softmax", kOnnxDomain),
{utils::GetFullQualifiedOpName("Softmax", kOnnxDomain),
OpPassThroughConfig<UpStreamGatherOperatorActorBase>(std::make_shared<SoftmaxGatherActor>(),
opset_13_11_1)},
{GetFullQualifiedOpName("Transpose", kOnnxDomain),
{utils::GetFullQualifiedOpName("Transpose", kOnnxDomain),
OpPassThroughConfig<UpStreamGatherOperatorActorBase>(std::make_shared<TransposeGatherActor>(),
opset_13_1)},
});
Expand All @@ -69,7 +70,7 @@ bool UpStreamGatherGraphTransformer::UpStreamInternal(
const OpPassThroughConfig<UpStreamGatherOperatorActorBase>& pass_through_config,
const logging::Logger& logger) const {
Node& slice_node = *info.node_ptr;
const std::string op_type = GetFullQualifiedOpName(current_node.OpType(), current_node.Domain());
const std::string op_type = utils::GetFullQualifiedOpName(current_node.OpType(), current_node.Domain());

std::unordered_map<int, int> propagate_input_indices;
std::unordered_map<int, std::vector<DimCompare>> all_input_cmp_rets;
Expand Down
15 changes: 8 additions & 7 deletions onnxruntime/core/optimizer/compute_optimizer/upstream_reshape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#ifdef ENABLE_TRAINING

#include "core/framework/tensorprotoutils.h"
#include "core/common/string_utils.h"
#include "core/graph/graph_utils.h"
#include "core/optimizer/utils.h"
#include "core/optimizer/compute_optimizer/upstream_reshape_actors.h"
Expand All @@ -21,23 +22,23 @@ UpStreamReshapeGraphTransformer::UpStreamReshapeGraphTransformer(
// If optype is not enough to guarantee the equivalence, we need to add a customized pre-check function.
// 2. Should all inputs be allowed when tracking back further (bottom-up);
// if not, add the input index restriction.
{GetFullQualifiedOpName("Add", kOnnxDomain),
{utils::GetFullQualifiedOpName("Add", kOnnxDomain),
OpPassThroughConfig<UpStreamReshapeOperatorActorBase>(
std::make_shared<SimplePointwiseReshapeActor<true>>(), opset_14_13_7_6_1)},
{GetFullQualifiedOpName("BiasGelu", kMSDomain),
{utils::GetFullQualifiedOpName("BiasGelu", kMSDomain),
OpPassThroughConfig<UpStreamReshapeOperatorActorBase>(
std::make_shared<SimplePointwiseReshapeActor<true>>(), opset_1)},
{GetFullQualifiedOpName("Cast", kOnnxDomain),
{utils::GetFullQualifiedOpName("Cast", kOnnxDomain),
OpPassThroughConfig<UpStreamReshapeOperatorActorBase>(
std::make_shared<SimplePointwiseReshapeActor<true>>(), opset_19_13_9_6_1)},
{GetFullQualifiedOpName("Dropout", kOnnxDomain),
{utils::GetFullQualifiedOpName("Dropout", kOnnxDomain),
OpPassThroughConfig<UpStreamReshapeOperatorActorBase>(
std::make_shared<SimplePointwiseReshapeActor<true>>(), opset_13_12_10_7_6_1)},
{// Be noted, this is our own implementation of ONNX domain op.
GetFullQualifiedOpName("LayerNormalization", kOnnxDomain),
utils::GetFullQualifiedOpName("LayerNormalization", kOnnxDomain),
OpPassThroughConfig<UpStreamReshapeOperatorActorBase>(
std::make_shared<LayerNormalizationReshapeActor>(), opset_1)},
{GetFullQualifiedOpName("MatMul", kOnnxDomain),
{utils::GetFullQualifiedOpName("MatMul", kOnnxDomain),
OpPassThroughConfig<UpStreamReshapeOperatorActorBase>(
std::make_shared<MatMulReshapeActor>(), opset_13_9_1)},
});
Expand All @@ -47,7 +48,7 @@ bool UpStreamReshapeGraphTransformer::UpStreamInternal(
Graph& graph, std::deque<ReshapeInfo>& queue, Node& current_node, ReshapeInfo& info,
const OpPassThroughConfig<UpStreamReshapeOperatorActorBase>& pass_through_config,
const logging::Logger& logger) const {
const std::string op_type = GetFullQualifiedOpName(current_node.OpType(), current_node.Domain());
const std::string op_type = utils::GetFullQualifiedOpName(current_node.OpType(), current_node.Domain());

std::vector<int> propagate_input_indices;
std::unordered_map<int, std::vector<DimCompare>> all_input_cmp_rets;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include <onnx/defs/attr_proto_util.h>
#include "core/common/safeint.h"
#include "core/common/string_utils.h"
#include "core/graph/graph_utils.h"
#include "core/optimizer/initializer.h"
#include "core/optimizer/utils.h"
Expand Down Expand Up @@ -130,7 +131,7 @@ template <typename T1, typename T2>
bool UpStreamGraphTransformerBase<T1, T2>::Upstream(Graph& graph, std::deque<T1>& queue,
Node& current_node, T1& info,
const logging::Logger& logger) const {
const std::string op_type = GetFullQualifiedOpName(current_node.OpType(), current_node.Domain());
const std::string op_type = utils::GetFullQualifiedOpName(current_node.OpType(), current_node.Domain());
if (allowed_passthrough_ops_.count(op_type)) {
auto& pass_through_config = allowed_passthrough_ops_.at(op_type);
LOG_DEBUG_INFO(logger, "Enter reorder handle for node " + current_node.Name() + "(" + op_type + ")");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,6 @@ class UpStreamGraphTransformerBase : public GraphTransformer {
const OpPassThroughConfig<T2>& pass_through_config,
const logging::Logger& logger) const = 0;

/**
* @brief A consistent way to construct the full qualified op name.
*/
std::string GetFullQualifiedOpName(const std::string& op_type, const std::string& domain) const {
return domain + "::" + op_type;
}

std::unordered_map<std::string, OpPassThroughConfig<T2>> allowed_passthrough_ops_;

private:
Expand Down
Loading

0 comments on commit 91878b2

Please sign in to comment.