From 1a9193a056e4dea0b51cb3392189aeca62cb98db Mon Sep 17 00:00:00 2001 From: rui-ren Date: Fri, 26 Jan 2024 19:41:02 +0000 Subject: [PATCH] update the lintrunner --- .../core/optimizer/gather_slice_fusion.cc | 17 +++++++++++------ .../core/optimizer/graph_transformer_utils.cc | 2 +- .../optimizer/graph_transform_test_builder.cc | 2 +- .../core/optimizer/graph_transformer_utils.cc | 1 + 4 files changed, 14 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/optimizer/gather_slice_fusion.cc b/onnxruntime/core/optimizer/gather_slice_fusion.cc index 06dc1340797e6..5bdbcac80c4fa 100644 --- a/onnxruntime/core/optimizer/gather_slice_fusion.cc +++ b/onnxruntime/core/optimizer/gather_slice_fusion.cc @@ -8,6 +8,8 @@ namespace onnxruntime { +// Check valid Gather Ops of version {1, 11, 13} +// Add get the parameters of index, axis, indices_n_dims bool GatherSliceToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, int64_t& indices_n_dims) const { if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {1, 11, 13}) || @@ -41,6 +43,8 @@ bool GatherSliceToSplitFusion::IsSupportedGather(const Graph& graph, const Node& return true; } +// Check valid Slice Ops of version {1, 10, 11, 13} +// Add get the parameters of starts, ends, axes, steps bool GatherSliceToSplitFusion::IsSupportedSlice(const Graph& graph, const Node& node, InlinedVector& starts, InlinedVector& ends, @@ -59,6 +63,7 @@ bool GatherSliceToSplitFusion::IsSupportedSlice(const Graph& graph, const Node& } // If Slice op of opset version 1 + // Node inputs include: starts/ends if (onnx_opset_version == 1) { if (!graph_utils::GetRepeatedNodeAttributeValues(node, "starts", starts) || !graph_utils::GetRepeatedNodeAttributeValues(node, "ends", ends) || @@ -72,9 +77,8 @@ bool GatherSliceToSplitFusion::IsSupportedSlice(const Graph& graph, const Node& } // If Slice op of opset version >= 10 + // Node inputs include: starts/ends/axes/steps if (onnx_opset_version >= 10) { - // node inputs include: starts - ends - axes - steps - // return a pointer to the corresponding NodeArg if input of the node at the index exists auto get_input_if_exists = [&node](size_t input_index) -> const NodeArg* { const auto& input_defs = node.InputDefs(); @@ -105,7 +109,7 @@ bool GatherSliceToSplitFusion::IsSupportedSlice(const Graph& graph, const Node& return {}; }; - // starts and ends inputs have to exist, be constants and be of the same size. + // starts/ends/axes/steps inputs have to exist, be constants. const ONNX_NAMESPACE::TensorProto* starts_init = get_initializer_if_constant(1); const ONNX_NAMESPACE::TensorProto* ends_init = get_initializer_if_constant(2); const ONNX_NAMESPACE::TensorProto* axes_init = get_initializer_if_constant(3); @@ -115,6 +119,7 @@ bool GatherSliceToSplitFusion::IsSupportedSlice(const Graph& graph, const Node& return false; } + // starts/ends/axes/steps inputs have to exist, be the same szie. starts = get_initializer_data(starts_init); ends = get_initializer_data(ends_init); axes = get_initializer_data(axes_init); @@ -148,7 +153,7 @@ GatherToSplitFusion is to fuse: Node |-> Gather(index=0, axis=axis) |-> Gather(index=1, axis=axis) - |-> Slice(index=2, axis=axis) + |-> Slice (index=2, axis=axis) To Node |-> Split(index=0) @@ -173,7 +178,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); - // Currently only catch after Reshape ops, optimize in the future + // TODO: Currently only catch after Reshape ops, optimize in the future if (node.OpType() != "Reshape") continue; size_t output_count = node.GetOutputEdgesCount(); @@ -322,7 +327,7 @@ Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int gra {graph.GetNodeArg(node_arg->Name()), split_arg}, split_outputs); split_node.AddAttribute("axis", split_axis); - // to do here + split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); int onnx_opset_version = -1; diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index dbfa5c1014bfb..4e939fe3c7b6b 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -309,7 +309,7 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); - // transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); diff --git a/onnxruntime/test/optimizer/graph_transform_test_builder.cc b/onnxruntime/test/optimizer/graph_transform_test_builder.cc index 1ad45d9d27aad..a5024f510b3cd 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_builder.cc +++ b/onnxruntime/test/optimizer/graph_transform_test_builder.cc @@ -15,7 +15,7 @@ #include "test/util/include/inference_session_wrapper.h" // enable to dump model for debugging -#define SAVE_TEST_GRAPH 1 +#define SAVE_TEST_GRAPH 0 namespace onnxruntime { namespace test { diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index 0b68dc65e41cd..4340069eee110 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -142,6 +142,7 @@ std::vector> GeneratePreTrainingTransformers( transformers.emplace_back(std::make_unique(compatible_eps)); transformers.emplace_back(std::make_unique(compatible_eps)); transformers.emplace_back(std::make_unique(compatible_eps)); + // If a model with Q, DQ nodes is being used for the purpose of training, it must be for // Quantization Aware Training. So, replace QDQ nodes with FakeQuant. transformers.emplace_back(std::make_unique(compatible_eps));