From 91878b2d810399122f4da51e6d9a4a828691e1c8 Mon Sep 17 00:00:00 2001 From: pengwa Date: Thu, 7 Mar 2024 09:12:12 +0800 Subject: [PATCH] Define recomputable op list with domain/opset (#19722) ### Define recomputable op list with domain/opset Originally, we just check the OpType and decide whether it is recomputable. In this PR, few improvements are made: 1. [Op type search] Domain + OpType are used to check whether the op is supported to recompute. 2. [Opset search] Then, node.SinceVersion() will be searched in the supported opsets. 3. During subgraph detection, If the node in that this opset is supported, get the ignorable input indices, which means we don't consider in the bottom-up search. This would save time for the subgraph detection. ### Motivation and Context --- onnxruntime/core/common/string_utils.h | 9 +- .../compute_optimizer/upstream_gather.cc | 25 +- .../compute_optimizer/upstream_reshape.cc | 15 +- .../upstream_transformer_base.cc | 3 +- .../upstream_transformer_base.h | 7 - .../memory_optimizer/recompute_analysis.cc | 414 +++++++++++++++--- 6 files changed, 382 insertions(+), 91 deletions(-) diff --git a/onnxruntime/core/common/string_utils.h b/onnxruntime/core/common/string_utils.h index 03e94cefd0564..716eed1afec51 100644 --- a/onnxruntime/core/common/string_utils.h +++ b/onnxruntime/core/common/string_utils.h @@ -66,7 +66,14 @@ inline std::string TrimString(std::string s) { } /** - * So use this simple hash to generate unique int by given string input. + * @brief A consistent way to construct the full qualified op name. + */ +inline std::string GetFullQualifiedOpName(const std::string& op_type, const std::string& domain) { + return MakeString(domain, "::", op_type); +} + +/** + * Use this simple hash to generate unique int by given string input. */ inline uint32_t GetHashFromString(const std::string& str_value) { uint32_t hash = 0; diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc index 9c98ed6d3e114..1516fb37a7e9f 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc +++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc @@ -4,6 +4,7 @@ #ifdef ENABLE_TRAINING #include +#include "core/common/string_utils.h" #include "core/graph/graph_utils.h" #include "core/optimizer/initializer.h" #include "core/optimizer/utils.h" @@ -26,38 +27,38 @@ UpStreamGatherGraphTransformer::UpStreamGatherGraphTransformer( // 2. Whether the outputs have the same dim changes if the Gather node moves before that operator. // 3. Should all inputs be allowed when tracking back further (bottom-up); // if not, add the input index restriction as MatMul did. - {GetFullQualifiedOpName("Add", kOnnxDomain), + {utils::GetFullQualifiedOpName("Add", kOnnxDomain), OpPassThroughConfig(std::make_shared>(), opset_14_13_7_6_1)}, - {GetFullQualifiedOpName("BiasGelu", kMSDomain), + {utils::GetFullQualifiedOpName("BiasGelu", kMSDomain), OpPassThroughConfig(std::make_shared>(), opset_1)}, - {GetFullQualifiedOpName("Cast", kOnnxDomain), + {utils::GetFullQualifiedOpName("Cast", kOnnxDomain), OpPassThroughConfig(std::make_shared>(), opset_19_13_9_6_1)}, - {GetFullQualifiedOpName("Div", kOnnxDomain), + {utils::GetFullQualifiedOpName("Div", kOnnxDomain), OpPassThroughConfig(std::make_shared>(), opset_14_13_7_6_1)}, - {GetFullQualifiedOpName("Dropout", kOnnxDomain), + {utils::GetFullQualifiedOpName("Dropout", kOnnxDomain), OpPassThroughConfig(std::make_shared>(), opset_13_12_10_7_6_1)}, - {GetFullQualifiedOpName("Gelu", kMSDomain), + {utils::GetFullQualifiedOpName("Gelu", kMSDomain), OpPassThroughConfig(std::make_shared>(), opset_1)}, {// Be noted, this is our own implementation of ONNX domain op. - GetFullQualifiedOpName("LayerNormalization", kOnnxDomain), + utils::GetFullQualifiedOpName("LayerNormalization", kOnnxDomain), OpPassThroughConfig(std::make_shared(), opset_1)}, - {GetFullQualifiedOpName("MatMul", kOnnxDomain), + {utils::GetFullQualifiedOpName("MatMul", kOnnxDomain), OpPassThroughConfig(std::make_shared(), opset_13_9_1)}, - {GetFullQualifiedOpName("Reshape", kOnnxDomain), + {utils::GetFullQualifiedOpName("Reshape", kOnnxDomain), OpPassThroughConfig(std::make_shared(), opset_19_14_13_5_1)}, - {GetFullQualifiedOpName("Softmax", kOnnxDomain), + {utils::GetFullQualifiedOpName("Softmax", kOnnxDomain), OpPassThroughConfig(std::make_shared(), opset_13_11_1)}, - {GetFullQualifiedOpName("Transpose", kOnnxDomain), + {utils::GetFullQualifiedOpName("Transpose", kOnnxDomain), OpPassThroughConfig(std::make_shared(), opset_13_1)}, }); @@ -69,7 +70,7 @@ bool UpStreamGatherGraphTransformer::UpStreamInternal( const OpPassThroughConfig& pass_through_config, const logging::Logger& logger) const { Node& slice_node = *info.node_ptr; - const std::string op_type = GetFullQualifiedOpName(current_node.OpType(), current_node.Domain()); + const std::string op_type = utils::GetFullQualifiedOpName(current_node.OpType(), current_node.Domain()); std::unordered_map propagate_input_indices; std::unordered_map> all_input_cmp_rets; diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape.cc b/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape.cc index f7b48de2caaf5..716988e93312c 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape.cc +++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_reshape.cc @@ -4,6 +4,7 @@ #ifdef ENABLE_TRAINING #include "core/framework/tensorprotoutils.h" +#include "core/common/string_utils.h" #include "core/graph/graph_utils.h" #include "core/optimizer/utils.h" #include "core/optimizer/compute_optimizer/upstream_reshape_actors.h" @@ -21,23 +22,23 @@ UpStreamReshapeGraphTransformer::UpStreamReshapeGraphTransformer( // If optype is not enough to guarantee the equivalence, we need to add a customized pre-check function. // 2. Should all inputs be allowed when tracking back further (bottom-up); // if not, add the input index restriction. - {GetFullQualifiedOpName("Add", kOnnxDomain), + {utils::GetFullQualifiedOpName("Add", kOnnxDomain), OpPassThroughConfig( std::make_shared>(), opset_14_13_7_6_1)}, - {GetFullQualifiedOpName("BiasGelu", kMSDomain), + {utils::GetFullQualifiedOpName("BiasGelu", kMSDomain), OpPassThroughConfig( std::make_shared>(), opset_1)}, - {GetFullQualifiedOpName("Cast", kOnnxDomain), + {utils::GetFullQualifiedOpName("Cast", kOnnxDomain), OpPassThroughConfig( std::make_shared>(), opset_19_13_9_6_1)}, - {GetFullQualifiedOpName("Dropout", kOnnxDomain), + {utils::GetFullQualifiedOpName("Dropout", kOnnxDomain), OpPassThroughConfig( std::make_shared>(), opset_13_12_10_7_6_1)}, {// Be noted, this is our own implementation of ONNX domain op. - GetFullQualifiedOpName("LayerNormalization", kOnnxDomain), + utils::GetFullQualifiedOpName("LayerNormalization", kOnnxDomain), OpPassThroughConfig( std::make_shared(), opset_1)}, - {GetFullQualifiedOpName("MatMul", kOnnxDomain), + {utils::GetFullQualifiedOpName("MatMul", kOnnxDomain), OpPassThroughConfig( std::make_shared(), opset_13_9_1)}, }); @@ -47,7 +48,7 @@ bool UpStreamReshapeGraphTransformer::UpStreamInternal( Graph& graph, std::deque& queue, Node& current_node, ReshapeInfo& info, const OpPassThroughConfig& pass_through_config, const logging::Logger& logger) const { - const std::string op_type = GetFullQualifiedOpName(current_node.OpType(), current_node.Domain()); + const std::string op_type = utils::GetFullQualifiedOpName(current_node.OpType(), current_node.Domain()); std::vector propagate_input_indices; std::unordered_map> all_input_cmp_rets; diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.cc b/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.cc index f08e37296d259..4582f26a7dc68 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.cc +++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.cc @@ -5,6 +5,7 @@ #include #include "core/common/safeint.h" +#include "core/common/string_utils.h" #include "core/graph/graph_utils.h" #include "core/optimizer/initializer.h" #include "core/optimizer/utils.h" @@ -130,7 +131,7 @@ template bool UpStreamGraphTransformerBase::Upstream(Graph& graph, std::deque& queue, Node& current_node, T1& info, const logging::Logger& logger) const { - const std::string op_type = GetFullQualifiedOpName(current_node.OpType(), current_node.Domain()); + const std::string op_type = utils::GetFullQualifiedOpName(current_node.OpType(), current_node.Domain()); if (allowed_passthrough_ops_.count(op_type)) { auto& pass_through_config = allowed_passthrough_ops_.at(op_type); LOG_DEBUG_INFO(logger, "Enter reorder handle for node " + current_node.Name() + "(" + op_type + ")"); diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.h b/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.h index 6e22fc791ade3..d848a03c555bb 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.h +++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_transformer_base.h @@ -72,13 +72,6 @@ class UpStreamGraphTransformerBase : public GraphTransformer { const OpPassThroughConfig& pass_through_config, const logging::Logger& logger) const = 0; - /** - * @brief A consistent way to construct the full qualified op name. - */ - std::string GetFullQualifiedOpName(const std::string& op_type, const std::string& domain) const { - return domain + "::" + op_type; - } - std::unordered_map> allowed_passthrough_ops_; private: diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc index 76b3325f36116..b421eb2ab32da 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc @@ -48,75 +48,352 @@ float InputOutputSizeRatio(const Node* node) { return 1.0f; } +using IgnorableInputIndices = InlinedVector; +using OpsetToIgnorableIndicesMap = InlinedHashMap; + /** - * @brief Used to define per-op recompute config. + * @brief Get the Allowed Recompute Ops object + * + * The supported op types are predefined. + * Most recent revisited for ONNX v1.15.0 release - https://github.com/onnx/onnx/blob/b86cc54efce19530fb953e4b21f57e6b3888534c/docs/Operators.md * + * We defined supported list explicitly instead of using a excluding list for the following reasons: + * 1. Some ops generate indeterministic results (for example using random number generator). We need evaluate whether + * this is a problem for recompute before adding the support, instead of fixing this after we find and try to + * fix convergence issues (which will be very hard if we have multiple indeterministic operators by default supported.) + * 2. Some ops schema will be changed in new opsets, we need also check manually whether it is applicable to recompute + * or not. + * 3. Some ops are not supported in older opsets, we need to check whether it is applicable to recompute or not. */ -struct AllowedRecomputeNodeConfig { - InlinedVector input_arg_indices; // input index to iterate further (bottom up) -}; - -// The supported op types are predefined. - -const InlinedHashMap& GetAllowedRecomputeOps(int probe_op_level) { - static InlinedHashMap> recomputable_op_table_map; +const InlinedHashMap& GetAllowedRecomputeOps(int probe_op_level) { + static InlinedHashMap> recomputable_op_table_map; if (recomputable_op_table_map.find(probe_op_level) != recomputable_op_table_map.end()) { return recomputable_op_table_map.at(probe_op_level); } - recomputable_op_table_map.insert({probe_op_level, InlinedHashMap()}); + recomputable_op_table_map.insert({probe_op_level, InlinedHashMap()}); auto& recomputable_op_table = recomputable_op_table_map.at(probe_op_level); if (probe_op_level >= static_cast(ProbeLevel::Basic)) { recomputable_op_table.insert({ - // Binary elementwise - {"Add", AllowedRecomputeNodeConfig{{0, 1}}}, - {"BiasGelu", AllowedRecomputeNodeConfig{{0, 1}}}, - {"Div", AllowedRecomputeNodeConfig{{0, 1}}}, - {"Equal", AllowedRecomputeNodeConfig{{0, 1}}}, - {"Mul", AllowedRecomputeNodeConfig{{0, 1}}}, - {"Sub", AllowedRecomputeNodeConfig{{0, 1}}}, - - // Data layout - /// The shape input is trivial whether it exists or not in backward. - {"Reshape", AllowedRecomputeNodeConfig{{0}}}, - {"Shape", AllowedRecomputeNodeConfig{{0}}}, - {"Squeeze", AllowedRecomputeNodeConfig{{0}}}, - {"Transpose", AllowedRecomputeNodeConfig{{0}}}, - {"Unsqueeze", AllowedRecomputeNodeConfig{{0}}}, - - // Unary elementwise - {"Dropout", AllowedRecomputeNodeConfig{{0}}}, - {"BiasGelu", AllowedRecomputeNodeConfig{{0, 1}}}, - /// The ratio and mode input are trivial whether they exist or not in backward - {"BitmaskDropout", AllowedRecomputeNodeConfig{{0}}}, - /// The axis input is trivial whether it exists or not in backward - {"CumSum", AllowedRecomputeNodeConfig{{0}}}, - {"Expand", AllowedRecomputeNodeConfig{{0}}}, - {"FastGelu", AllowedRecomputeNodeConfig{{0}}}, - {"Gelu", AllowedRecomputeNodeConfig{{0}}}, - {"QuickGelu", AllowedRecomputeNodeConfig{{0}}}, - - // Ternary elementwise - {"Where", AllowedRecomputeNodeConfig{{0, 1, 2}}}, - - // Data copy - {"Tile", AllowedRecomputeNodeConfig{{0}}}, - {"Cast", AllowedRecomputeNodeConfig{{0}}}, - {"ConcatTraining", AllowedRecomputeNodeConfig{{0, 1}}}, // Input could be more than 2. But mostly 2. - {"Slice", AllowedRecomputeNodeConfig{{0}}}, - {"Split", AllowedRecomputeNodeConfig{{0}}}, - {"Gather", AllowedRecomputeNodeConfig{{0}}}, + { + utils::GetFullQualifiedOpName("Add", kOnnxDomain), + { + {1, {}}, + {6, {}}, + {7, {}}, + {13, {}}, + {14, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("BatchNormalization", kOnnxDomain), + { + {1, {}}, + {6, {}}, + {7, {}}, + {9, {}}, + {14, {}}, + {15, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("BiasGelu", kMSDomain), + { + {1, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("BiasDropout", kMSDomain), + { + {1, {3, 4}}, // ignore ratio (optional) and training mode (optional) + }, + }, + { + utils::GetFullQualifiedOpName("BitmaskBiasDropout", kMSDomain), + { + {1, {3, 4}}, // ignore ratio (optional) and training mode (optional) + }, + }, + { + utils::GetFullQualifiedOpName("BitmaskDropout", kMSDomain), + { + {1, {1, 2}}, // ignore ratio (optional) and training mode (optional) + }, + }, + { + utils::GetFullQualifiedOpName("Cast", kOnnxDomain), + { + {1, {}}, + {6, {}}, + {9, {}}, + {13, {}}, + {19, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("ConcatTraining", kMSDomain), + { + {1, {}}, + + }, + }, + { + utils::GetFullQualifiedOpName("ConstantOfShape", kOnnxDomain), + { + {9, {0}}, // ignore the `input`, e.g. the shape of the expected output tensor + {20, {0}}, + }, + }, + { + utils::GetFullQualifiedOpName("Dropout", kOnnxDomain), + { + // ONNX Dropout 1, 6, 7, 10 do not have seed attribute, so we remove them from the recompute support. + {12, {1, 2}}, // ignore ratio and training_mode + {13, {1, 2}}, + }, + }, + { + utils::GetFullQualifiedOpName("Div", kOnnxDomain), + { + {1, {}}, + {6, {}}, + {7, {}}, + {13, {}}, + {14, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Expand", kOnnxDomain), + { + {8, {1}}, // Ignore the shape. + {13, {1}}, + }, + }, + { + utils::GetFullQualifiedOpName("Cos", kOnnxDomain), + { + {7, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("CumSum", kOnnxDomain), + { + // The axis input is trivial + {11, {1}}, + {14, {1}}, + }, + }, + { + utils::GetFullQualifiedOpName("Einsum", kOnnxDomain), + { + {12, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Equal", kOnnxDomain), + { + {1, {}}, + {7, {}}, + {11, {}}, + {13, {}}, + {19, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("FastGelu", kMSDomain), + { + {1, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Gather", kOnnxDomain), + { + {1, {1}}, // ignore the indices + {11, {1}}, + {13, {1}}, + }, + }, + { + utils::GetFullQualifiedOpName("Gelu", kOnnxDomain), + { + {20, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Gelu", kMSDomain), + { + {1, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Less", kOnnxDomain), + { + {1, {}}, + {7, {}}, + {9, {}}, + {13, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Mul", kOnnxDomain), + { + {1, {}}, + {6, {}}, + {7, {}}, + {13, {}}, + {14, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Range", kOnnxDomain), + { + {11, {0, 1, 2}}, // ignore start, end, delta, because they are scalars. + }, + }, + { + utils::GetFullQualifiedOpName("Reshape", kOnnxDomain), + { + {1, {}}, + {5, {}}, // ignore the shape. + {13, {}}, + {14, {}}, + {19, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Sin", kOnnxDomain), + { + {7, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Slice", kOnnxDomain), + { + {1, {}}, + {10, {1, 2, 3, 4}}, // ignore starts, ends, axes (optional) and steps (optional) + {11, {1, 2, 3, 4}}, + {13, {1, 2, 3, 4}}, + }, + }, + { + utils::GetFullQualifiedOpName("Split", kOnnxDomain), + { + {1, {1}}, // ignore split (optional) + {2, {}}, + {11, {}}, + {13, {1}}, // ignore the split (optional) + {18, {1}}, + }, + }, + { + utils::GetFullQualifiedOpName("Squeeze", kOnnxDomain), + { + {1, {}}, + {11, {}}, + {13, {1}}, // ignore the axes (optional) + }, + }, + { + utils::GetFullQualifiedOpName("Sub", kOnnxDomain), + { + {1, {}}, + {6, {}}, + {7, {}}, + {13, {}}, + {14, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Tile", kOnnxDomain), + { + {1, {1, 2}}, + {6, {1}}, + {13, {1}}, + }, + }, + { + utils::GetFullQualifiedOpName("Transpose", kOnnxDomain), + { + {1, {}}, + {13, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Trilu", kOnnxDomain), + { + {14, {1}}, // ignore k (optional) + }, + }, + { + utils::GetFullQualifiedOpName("QuickGelu", kMSDomain), + { + {1, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Unsqueeze", kOnnxDomain), + { + {1, {}}, + {11, {}}, + {13, {1}}, // ignore the axes (optional) + }, + }, + { + utils::GetFullQualifiedOpName("Where", kOnnxDomain), + { + {9, {}}, + {16, {}}, + }, + }, + }); } if (probe_op_level >= static_cast(ProbeLevel::Advanced)) { recomputable_op_table.insert({ - {"LayerNormalization", AllowedRecomputeNodeConfig{{0, 1, 2}}}, - {"MatMul", AllowedRecomputeNodeConfig{{0, 1}}}, - {"FusedMatMul", AllowedRecomputeNodeConfig{{0, 1}}}, - {"Softmax", AllowedRecomputeNodeConfig{{0}}}, - {"BiasSoftmax", AllowedRecomputeNodeConfig{{0, 1}}}, - {"BiasSoftmaxDropout", AllowedRecomputeNodeConfig{{0, 1}}}, + { + utils::GetFullQualifiedOpName("BiasSoftmax", kMSDomain), + { + {1, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("BiasSoftmaxDropout", kMSDomain), + { + {1, {2}}, // ignore ratio (optional) + }, + }, + { + utils::GetFullQualifiedOpName("LayerNormalization", kOnnxDomain), + { + // Opset 1 in ONNX official does not have LayerNormalization, + // while our contrib op defined LayerNormalization in opset 1 in ONNX domain. + {1, {}}, + {17, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("MatMul", kOnnxDomain), + { + {1, {}}, + {9, {}}, + {13, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("FusedMatMul", kMSDomain), + { + {1, {}}, + }, + }, + { + utils::GetFullQualifiedOpName("Softmax", kOnnxDomain), + { + {1, {}}, + {11, {}}, + {13, {}}, + }, + }, }); } @@ -127,8 +404,20 @@ const InlinedHashMap& GetAllowedRecompu * @brief Check whether a node is a recomputable node at given probe level. */ bool IsRecomputable(const Node& node, ProbeLevel probe_level) { - const auto& op_table = GetAllowedRecomputeOps(static_cast(probe_level)); - return op_table.find(node.OpType()) != op_table.end(); + const InlinedHashMap& op_table = GetAllowedRecomputeOps(static_cast(probe_level)); + auto it = op_table.find(utils::GetFullQualifiedOpName(node.OpType(), node.Domain())); + if (it == op_table.end()) { + return false; + } + return it->second.count(node.SinceVersion()); +} + +const InlinedVector& GetIgnorableInputIndices(const Node& node, ProbeLevel probe_level) { + const InlinedHashMap& op_table = GetAllowedRecomputeOps(static_cast(probe_level)); + auto it = op_table.find(utils::GetFullQualifiedOpName(node.OpType(), node.Domain())); + ORT_ENFORCE(it != op_table.end(), "Cannot get ignorable indices since the node type is supported in the list."); + ORT_ENFORCE(it->second.count(node.SinceVersion()) > 0, "Cannot get ignorable indices since the opset is supported"); + return it->second.at(node.SinceVersion()); } /** @@ -163,7 +452,6 @@ Status SelectRecomputeSubgraph(const Node& entry_node, bool& can_compromise_stashed_activation, float& save_ratio) { const ProbeLevel probe_level = probe_config.probe_level; - const auto& recomputable_op_table = GetAllowedRecomputeOps(static_cast(probe_level)); can_compromise_stashed_activation = false; @@ -213,7 +501,7 @@ Status SelectRecomputeSubgraph(const Node& entry_node, // If current op is NOT in allowed list: // 1). the output does not exist in backward, we cannot find a good solution for so, the search terminates. // 2). the output is used in backward, we don't need to trace back further, so continue searching. - auto op_recompute_config_it = recomputable_op_table.find(curr_node->OpType()); + bool is_recomputable = IsRecomputable(*curr_node, probe_level); auto cur_output_arg_name = curr_node->OutputDefs()[p.second]->Name(); if (is_first_queue_scan) { // We handle the entry node outputs differently because, we don't want this case falls into and succeed one of @@ -221,14 +509,14 @@ Status SelectRecomputeSubgraph(const Node& entry_node, // 1. "op is not in recompute op list, but its output is used in backward" // 2. "op is in recompute op list, but its output is used in backward" // (either of the above checks is true for entry node outputs) - if (op_recompute_config_it == recomputable_op_table.end()) { + if (!is_recomputable) { early_stop = true; MO_LOG_DEBUG_INFO(logger, "Entry Node " + curr_node->Name() + "(" + curr_node->OpType() + ") is **NOT** in recompute op list, search terminates."); break; } } else { - if (op_recompute_config_it == recomputable_op_table.end()) { + if (!is_recomputable) { if (fw_op_output_arg_used_map.at(cur_output_arg_name).second) { MO_LOG_DEBUG_INFO(logger, "Node " + curr_node->Name() + "(" + curr_node->OpType() + ") is **NOT** in recompute op list, but its output [" + @@ -283,14 +571,14 @@ Status SelectRecomputeSubgraph(const Node& entry_node, } // Iterate all input nodes according to allowed input arg index of the entry node. - const auto& input_arg_indices = op_recompute_config_it->second.input_arg_indices; + const auto& igorable_input_arg_indices = GetIgnorableInputIndices(*curr_node, probe_level); for (auto it = curr_node->InputEdgesBegin(), end = curr_node->InputEdgesEnd(); it != end; ++it) { const Node::EdgeEnd& input_edge = *it; const auto& parent_node = input_edge.GetNode(); const auto parent_node_output_index = input_edge.GetSrcArgIndex(); const auto current_node_input_index = input_edge.GetDstArgIndex(); - if (std::find(input_arg_indices.begin(), input_arg_indices.end(), current_node_input_index) != - input_arg_indices.end()) { + if (std::find(igorable_input_arg_indices.begin(), igorable_input_arg_indices.end(), current_node_input_index) == + igorable_input_arg_indices.end()) { // If the tensor size is constant and very small (Now < 1M), we stop adding the input edge into queue. auto output_shape = parent_node.OutputDefs()[parent_node_output_index]->Shape(); if (output_shape) {