diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc index 9f49db9f12b79..df81367c5bbee 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc @@ -245,16 +245,174 @@ static std::unique_ptr GetDQIfProducingValue(const api::GraphRef& : std::unique_ptr(); } +// Forward declarations. +static bool NormalizeAndValidateAxes(std::vector& axes, size_t rank); +static std::optional> ReadFromAttrOrInput(const api::GraphRef& graph, api::NodeRef& node, + std::string_view attr_name, size_t inp_index, + int64_t opset); +static int64_t UnsqueezeAxis(gsl::span positive_unsqueeze_axes, int64_t axis); + +// Quantization modes for QuantizeLinear/DequantizeLinear. +enum class QuantizationMode : uint8_t { + kUnknown, + kPerTensor, + kPerAxis, + kBlocked, +}; + +// Returns true if this optimizer supports the given quantization mode. +constexpr bool IsSupportedQuantizationMode(QuantizationMode mode) { + return (mode == QuantizationMode::kPerTensor) || (mode == QuantizationMode::kPerAxis); +} + +// Stores quantization info for a validated Q or DQ node. +struct QuantizationInfo { + QuantizationMode mode; + int64_t norm_axis; // 'normalized' axis is in the range [0, Input0Rank - 1] +}; + +// Returns quantization information (quantization mode, normalized axis). +// Returns std::nullopt if unable to get the quantization info or if the axis attribute is invalid. +static std::optional GetQuantizationInfo(const api::GraphRef& graph, + const api::NodeRef& q_or_dq_node) { + const std::vector inputs = q_or_dq_node.Inputs(); + + // Need to use the scale input's shape to determine the quantization mode. Can't just use the presence of the axis + // attribute because even per-tensor Q/DQ have a default axis of 1. + std::string_view scale_input = inputs[1]; + const std::unique_ptr scale_value_info = graph.GetValueInfo(scale_input); + std::optional> scale_shape = scale_value_info->Shape(); + if (!scale_shape) { + return std::nullopt; + } + + QuantizationInfo quant_info = {}; + + if (IsScalarOr1Element1DTensor(*scale_shape)) { + // A scalar or tensor scale with shape (1,) indicates per-tensor quantization. + quant_info.mode = QuantizationMode::kPerTensor; + quant_info.norm_axis = 1; // 1 is the default 'axis' even for per-tensor quantization (axis is ignored). + } else { + // This is either per-axis or blocked quantization. A non-zero block_size attribute indicates blocked quantization. + int64_t axis = q_or_dq_node.GetAttributeIntDefault("axis", 1); + const auto input0_info = graph.GetValueInfo(inputs[0]); + auto input0_rank = input0_info->ShapeRank(); + if (!input0_rank.has_value() || !NormalizeAndValidateAxis(axis, *input0_rank)) { + // Unable to normalize the DQ's axis. + // TODO(adrianlizarraga): Should look into a logging facility to make it easier to inspect issues. + return std::nullopt; + } + + int64_t block_size = q_or_dq_node.GetAttributeIntDefault("block_size", 0); + quant_info.norm_axis = axis; + quant_info.mode = block_size != 0 ? QuantizationMode::kBlocked : QuantizationMode::kPerAxis; + } + + return quant_info; +} + +/// +/// Represents a DequantizeLinear node that TransposeInputImpl() or UnsqueezeInput() can look past to transpose or +/// unsqueeze its input. +/// +class DQToLookPast { + public: + DQToLookPast(std::unique_ptr&& dq_node, QuantizationInfo quant_info) + : dq_node_(std::move(dq_node)), quant_info_(quant_info) { + assert(dq_node_ != nullptr); // Expect dq_node to be valid. + } + + DQToLookPast(DQToLookPast&& other) = default; + DQToLookPast& operator=(DQToLookPast&& other) = default; + + inline void DisconnectInput0() { + dq_node_->SetInput(0, ""); + } + + inline void ReconnectInput0(std::string_view input_name) { + dq_node_->SetInput(0, input_name); + } + + inline std::string_view GetInput0() const { + return dq_node_->Inputs()[0]; + } + + inline std::string_view GetOutput() const { + return dq_node_->Outputs()[0]; + } + + /// + /// Sets the DQ's new transposed input[0]. The DQ's axis and output shape are updated. + /// + /// graph + /// name of new transposed input[0] + /// inverse transpose permutation used to update the DQ's axis + void SetTransposedInput(api::GraphRef& graph, std::string_view new_input, gsl::span perm_inv) { + if (quant_info_.mode == QuantizationMode::kPerAxis) { + quant_info_.norm_axis = perm_inv[gsl::narrow_cast(quant_info_.norm_axis)]; + } + + this->SetUpdatedInput(graph, new_input); + } + + /// + /// Sets the DQ's new unsqueezed input[0]. The DQ's axis and output shape are updated. + /// The provided unsqueeze axes must all be positive (i.e., normalized). + /// + /// graph + /// name of new unsqueezed input[0] + /// positive unsqueeze axes used to update the DQ's axis + void SetUnsqueezedInput(api::GraphRef& graph, std::string_view new_input, + gsl::span positive_unsqueeze_axes) { + if (quant_info_.mode == QuantizationMode::kPerAxis) { + quant_info_.norm_axis = UnsqueezeAxis(positive_unsqueeze_axes, quant_info_.norm_axis); + } + + this->SetUpdatedInput(graph, new_input); + } + + /// + /// Static function that returns/moves the DQ node from a std::optional DQToLookPast. + /// The std::optional DQToLookPast is set to std::nullopt to prevent accidental reuse. + /// + /// The DQToLookPast to steal the node from + /// The dq_node unique_ptr moved from dq_to_look_past + static std::unique_ptr TakeDQNode(std::optional& dq_to_look_past) { + std::unique_ptr node; + + if (dq_to_look_past) { + node = std::move(dq_to_look_past->dq_node_); + dq_to_look_past = std::nullopt; + } + + return node; + } + + private: + // Called by SetTransposedInput() and SetUnsqueezedInput() to update the DQ's input, + // axis, and output shape. + void SetUpdatedInput(api::GraphRef& graph, std::string_view new_input) { + dq_node_->SetInput(0, new_input); + dq_node_->SetAttributeInt("axis", quant_info_.norm_axis); + auto new_shape = *graph.GetValueInfo(new_input)->Shape(); + graph.GetValueInfo(dq_node_->Outputs()[0])->SetShape(&new_shape); + } + + std::unique_ptr dq_node_; + QuantizationInfo quant_info_; +}; + /// /// Return a DequantizeLinear node if it's input is a constant initializer and it has a single consumer. /// In this case the initializer can be updated in-place by UnsqueezeInput or TransposeInput. /// /// Current graph /// Value to check if produced by a DQ node who's input is a constant initializer -/// NodeRef for DQ node if it meets the requirements. -static std::unique_ptr GetDQWithConstInitializerInputAndSingleConsumer(const api::GraphRef& graph, - std::string_view value_name) { - std::unique_ptr result; +/// DQToLookPast for DQ node if it meets the requirements, or std::nullopt otherwise. +static std::optional GetDQWithConstInitializerInputAndSingleConsumer(const api::GraphRef& graph, + std::string_view value_name) { + std::optional result; + std::optional quant_info; auto dq_node = GetDQIfProducingValue(graph, value_name); if (dq_node) { @@ -267,10 +425,10 @@ static std::unique_ptr GetDQWithConstInitializerInputAndSingleCons break; } - // For now keep it simple and don't support per-axis quantization as that would require updating the axis of - // the DQ node during TransposeInputImpl and UnsqueezeInput. - auto dq_scale = graph.GetConstant(dq_node->Inputs()[1]); - if (!dq_scale || dq_scale->NumElements() != 1) { + // Get the quantization mode (per-tensor, per-channel) and the normalized quantization axis. + // To keep things simple, do not support blocked quantization for now (added in opset 21). + quant_info = GetQuantizationInfo(graph, *dq_node); + if (!quant_info || !IsSupportedQuantizationMode(quant_info->mode)) { break; } @@ -285,20 +443,13 @@ static std::unique_ptr GetDQWithConstInitializerInputAndSingleCons break; } - result = std::move(dq_node); + result = DQToLookPast(std::move(dq_node), *quant_info); } while (false); } return result; } -// Forward declarations for utils used by MakeQDQNodeUnit -static bool NormalizeAndValidateAxes(std::vector& axes, size_t rank); -static std::optional> ReadFromAttrOrInput(const api::GraphRef& graph, api::NodeRef& node, - std::string_view attr_name, size_t inp_index, - int64_t opset); -static int64_t UnsqueezeAxis(gsl::span sorted_positive_unsqueeze_axes, int64_t axis); - /// /// Insert a Q -> DQ pair after the node following the DQ by using scale and zp info from the preceding DQ node. /// DQ -> next node => DQ -> next node -> Q -> DQ. @@ -324,34 +475,21 @@ static bool MakeQDQNodeUnit(api::GraphRef& graph, const api::NodeRef& dq_node) { const bool is_unsqueeze = next_node.OpType() == "Unsqueeze"; const auto scale_input = dq_inputs[1]; - const auto scale_value_info = graph.GetValueInfo(scale_input); std::optional zp_input; - std::optional> zp_value_info; - - auto scale_shape = scale_value_info->Shape(); - if (!scale_shape) { - // axis potentially needs updating due to the transpose or unsqueeze but we don't have the required info to do it. - return false; - } if (dq_inputs.size() > 2) { zp_input = dq_inputs[2]; - zp_value_info = graph.GetValueInfo(zp_input.value()); } - // DQ uses per-axis quantization if its scale input is not a scalar and not a tensor with shape (1,). - // Note there could be an axis value as the onnx spec says that is ignored for per-tensor quantization, - // so we have to check the scale input's shape. - const bool update_dq_axis = !IsScalarOr1Element1DTensor(*scale_shape); - int64_t axis = dq_node.GetAttributeIntDefault("axis", 1); + std::optional dq_quant_info = GetQuantizationInfo(graph, dq_node); + if (!dq_quant_info || !IsSupportedQuantizationMode(dq_quant_info->mode)) { + return false; // Can't get the quantization mode/axis or is a quantization mode that is not supported. + } - if (update_dq_axis) { - const auto dq_input0_info = graph.GetValueInfo(dq_inputs[0]); - auto dq_input0_rank = dq_input0_info->ShapeRank(); - if (!dq_input0_rank.has_value() || !NormalizeAndValidateAxis(axis, *dq_input0_rank)) { - return false; // Unable to normalize the DQ's axis. - } + int64_t axis = dq_quant_info->norm_axis; + // Have to update the axis for newly inserted Q/DQ after a Transpose or Unsqueeze if using per-axis quantization. + if (dq_quant_info->mode == QuantizationMode::kPerAxis) { if (is_transpose) { auto perm = GetPermAttrIfValid(next_node); assert(perm.has_value()); // onnx shape inferencing checks that `perm` is valid @@ -360,15 +498,19 @@ static bool MakeQDQNodeUnit(api::GraphRef& graph, const api::NodeRef& dq_node) { auto axes = ReadFromAttrOrInput(graph, next_node, "axes", /*inp_index*/ 1, /*opset*/ 13); assert(axes.has_value()); // 'axes' are required for Unsqueeze + const auto dq_output_info = graph.GetValueInfo(dq_node.Outputs()[0]); + std::optional dq_output_rank = dq_output_info->ShapeRank(); + if (!dq_output_rank.has_value()) { + return false; // Need to know the rank of the input to the Unsqueeze to normalize unsqueeze axes + } + // Normalize negative unsqueeze axes by adding output rank. - // Unsqueeze output rank = input_rank + axes.size() - // Unsqueeze's input rank is the same as the DQ's input[0] rank. - if (!NormalizeAndValidateAxes(*axes, *dq_input0_rank + axes->size())) { + // Unsqueeze output rank = unsqueeze input rank + axes.size() + if (!NormalizeAndValidateAxes(*axes, *dq_output_rank + axes->size())) { return false; } // Need to update axis if Unsqueeze inserts a 1 before the axis dim. - std::sort(axes->begin(), axes->end()); axis = UnsqueezeAxis(*axes, axis); } } @@ -567,7 +709,7 @@ static std::optional> ReadFromAttrOrInput(const api::GraphR } // Computes inverse permutation. Unsafe if perm is not a valid permutation. -std::vector InvertPerm(const std::vector& perm) { +std::vector InvertPerm(gsl::span perm) { size_t rank = perm.size(); std::vector perm_inv(rank); for (size_t i = 0; i < rank; ++i) { @@ -754,15 +896,18 @@ static std::vector SqueezePerm(const std::vector& axes, const } // Computes a new axis value for an unsqueezed version of a tensor. Incorrect if any axes -// values are negative, duplicated, or are not sorted in increasing order. +// values are negative or duplicated. // // Ex: axes = [0, 1, 2], axis = 0, new_axis = 3 // axes = [0, 1, 3], axis = 1, new_axis = 4 -static int64_t UnsqueezeAxis(gsl::span sorted_positive_unsqueeze_axes, int64_t axis) { +static int64_t UnsqueezeAxis(gsl::span positive_unsqueeze_axes, int64_t axis) { assert(axis >= 0); int64_t new_axis = axis; - for (int64_t unsqueeze_axis : sorted_positive_unsqueeze_axes) { + std::vector sorted_axes(positive_unsqueeze_axes.begin(), positive_unsqueeze_axes.end()); + std::sort(sorted_axes.begin(), sorted_axes.end()); + + for (int64_t unsqueeze_axis : sorted_axes) { if (unsqueeze_axis <= new_axis) { new_axis += 1; } @@ -812,12 +957,6 @@ static std::vector SortedAxesForTransposedInput(const std::vectorShape(); - graph.GetValueInfo(dq.Outputs()[0])->SetShape(&new_shape); -} - /////// /////// /////// /////// @@ -834,27 +973,27 @@ static void UnsqueezeInput(OptimizerCtx& ctx, api::NodeRef& node, size_t i, cons std::unique_ptr constant = ctx.graph.GetLocalConstant(input); // allow a constant initializer coming via a DQ node with a single consumer - std::unique_ptr dq_node; + std::optional dq_to_look_past; std::string_view constant_dq_input; if (!constant) { // look past a DQ node for a constant initializer. essentially we pretend the DQ node doesn't exist // to enable directly making changes to the initializer. any nodes added for other consumers of the initializer // in 'Case 1' are prior to the DQ so we don't break up any QDQ node units. - dq_node = GetDQWithConstInitializerInputAndSingleConsumer(ctx.graph, input); - if (dq_node) { + dq_to_look_past = GetDQWithConstInitializerInputAndSingleConsumer(ctx.graph, input); + if (dq_to_look_past) { // underlying string for the input name is in the Node so it's safe to store in string_view constant_dq_input - constant_dq_input = dq_node->Inputs()[0]; + constant_dq_input = dq_to_look_past->GetInput0(); constant = ctx.graph.GetLocalConstant(constant_dq_input); // remove the DQ node as a consumer of the initializer while we modify things - dq_node->SetInput(0, ""); + dq_to_look_past->DisconnectInput0(); } } // Clear the input, which also removes this node's input as a consumer of the value. // NOTE: the node may have multiple inputs consuming the value. node.SetInput(i, ""); - auto value_to_modify = dq_node ? constant_dq_input : input; + auto value_to_modify = dq_to_look_past ? constant_dq_input : input; auto consumers = ctx.graph.GetValueConsumers(value_to_modify); // Case 1: input is a constant with a known list of consumer nodes @@ -873,8 +1012,8 @@ static void UnsqueezeInput(OptimizerCtx& ctx, api::NodeRef& node, size_t i, cons auto new_shape = UnsqueezeShape(constant->Shape(), axes); ctx.graph.ReshapeInitializer(value_to_modify, new_shape); - if (dq_node) { - UpdateDQNodeInputAndShape(ctx.graph, *dq_node, constant_dq_input); + if (dq_to_look_past) { + dq_to_look_past->SetUnsqueezedInput(ctx.graph, constant_dq_input, axes); } node.SetInput(i, input); // restore the original connection @@ -886,10 +1025,14 @@ static void UnsqueezeInput(OptimizerCtx& ctx, api::NodeRef& node, size_t i, cons // look past a DQ node for a Squeeze to cancel if (inp_node && inp_node->OpType() == "DequantizeLinear") { - dq_node = std::move(inp_node); - auto dq_input = dq_node->Inputs()[0]; - inp_node = ctx.graph.GetNodeProducingOutput(dq_input); - consumers = ctx.graph.GetValueConsumers(dq_input); + std::optional dq_quant_info = GetQuantizationInfo(ctx.graph, *inp_node); + + if (dq_quant_info && IsSupportedQuantizationMode(dq_quant_info->mode)) { + dq_to_look_past = std::make_optional(std::move(inp_node), *dq_quant_info); + auto dq_input = dq_to_look_past->GetInput0(); + inp_node = ctx.graph.GetNodeProducingOutput(dq_input); + consumers = ctx.graph.GetValueConsumers(dq_input); + } } if (inp_node != nullptr && inp_node->IsOp("Squeeze")) { @@ -897,9 +1040,9 @@ static void UnsqueezeInput(OptimizerCtx& ctx, api::NodeRef& node, size_t i, cons std::optional> squeeze_axes = std::nullopt; squeeze_axes = ReadFromAttrOrInput(ctx, *inp_node, "axes", /*inp_index*/ 1, /*opset*/ 13); if (squeeze_axes != std::nullopt && *squeeze_axes == axes) { - if (dq_node) { - UpdateDQNodeInputAndShape(ctx.graph, *dq_node, inp_node_inputs[0]); - node.SetInput(i, dq_node->Outputs()[0]); + if (dq_to_look_past) { + dq_to_look_past->SetUnsqueezedInput(ctx.graph, inp_node_inputs[0], axes); + node.SetInput(i, dq_to_look_past->GetOutput()); } else { node.SetInput(i, inp_node_inputs[0]); } @@ -907,7 +1050,7 @@ static void UnsqueezeInput(OptimizerCtx& ctx, api::NodeRef& node, size_t i, cons // Remove the Squeeze node if possible // if there's a DQ node the `consumers` list still includes it so allow for that. // in that case UpdateDQNodeInputAndShape already updated the input of the DQ node so it's safe to remove it. - if (consumers->comprehensive && consumers->nodes.size() == size_t(dq_node ? 1 : 0)) { + if (consumers->comprehensive && consumers->nodes.size() == size_t(dq_to_look_past ? 1 : 0)) { ctx.graph.RemoveNode(*inp_node); if (ctx.opset >= 13 && !ctx.graph.HasValueConsumers(inp_node_inputs[1])) { @@ -922,8 +1065,8 @@ static void UnsqueezeInput(OptimizerCtx& ctx, api::NodeRef& node, size_t i, cons } // any DQ node special casing doesn't apply anymore, so go back to the original inp_node - if (dq_node) { - inp_node = std::move(dq_node); + if (dq_to_look_past) { + inp_node = DQToLookPast::TakeDQNode(dq_to_look_past); } // Case 3: Add an Unsqueeze node. @@ -989,20 +1132,20 @@ static void TransposeInputImpl(api::GraphRef& graph, api::NodeRef& node, size_t std::unique_ptr constant = graph.GetLocalConstant(input); // allow a constant initializer coming via a DQ node with a single consumer - std::unique_ptr dq_node; + std::optional dq_to_look_past; std::string_view constant_dq_input; if (!constant) { // look past a DQ node for a constant initializer. essentially we pretend the DQ node doesn't exist // to enable directly making changes to the initializer. any nodes added for other consumers of the initializer // in 'Case 1' are prior to the DQ so we don't break up any QDQ node units. - dq_node = GetDQWithConstInitializerInputAndSingleConsumer(graph, input); - if (dq_node) { + dq_to_look_past = GetDQWithConstInitializerInputAndSingleConsumer(graph, input); + if (dq_to_look_past) { // underlying string for the input name is in the Node so it's safe to store in string_view constant_dq_input - constant_dq_input = dq_node->Inputs()[0]; + constant_dq_input = dq_to_look_past->GetInput0(); constant = graph.GetLocalConstant(constant_dq_input); // remove the DQ node as a consumer of the initializer while we modify things - dq_node->SetInput(0, ""); + dq_to_look_past->DisconnectInput0(); } } @@ -1010,7 +1153,7 @@ static void TransposeInputImpl(api::GraphRef& graph, api::NodeRef& node, size_t // NOTE: the node may have multiple inputs consuming the value. node.SetInput(i, ""); - auto constant_to_modify = dq_node ? constant_dq_input : input; + auto constant_to_modify = dq_to_look_past ? constant_dq_input : input; auto consumers = graph.GetValueConsumers(constant_to_modify); // Case 1: input is a constant with a known list of consumer nodes @@ -1018,13 +1161,14 @@ static void TransposeInputImpl(api::GraphRef& graph, api::NodeRef& node, size_t // we modify the initializer in-place and need to reconnect things up when we're done. this helper will // do that when it goes out of scope. if we have manually reconnected, input or constant_dq_input is // set to an empty string. - auto reconnect_nodes = gsl::finally([i, &node, &dq_node, &input, &constant_dq_input] { + auto reconnect_nodes = gsl::finally([i, &node, &dq_to_look_past, &input, &constant_dq_input] { if (!input.empty()) { node.SetInput(i, input); } if (!constant_dq_input.empty()) { - dq_node->SetInput(0, constant_dq_input); + assert(dq_to_look_past); + dq_to_look_past->ReconnectInput0(constant_dq_input); } }); @@ -1038,16 +1182,19 @@ static void TransposeInputImpl(api::GraphRef& graph, api::NodeRef& node, size_t // Permute1DConstant permutes the constant and adds a new initializer. The old initializer is removed only if // there are no other consumers. if (constant->Shape().size() == 1 && constant->Shape()[0] == gsl::narrow_cast(perm.size())) { - auto& node_to_update = dq_node ? *dq_node : node; - Permute1DConstant(graph, node_to_update, *constant, i, constant_to_modify, perm); + // A quantized (roi/scales/sizes) input for Resize or a quantized pads input for Pad would be unlikely. + // Even if it occurred, HandleResize()/HandlePad() permute these kinds of inputs directly and do not try to + // call TransposeInput() on them. Also, WrapTransposesAroundNode() does not call TransposeInput() on non-const + // inputs of this kind. So, we should not have a DQ to look past at this point. + // + // In the event that we decide to handle a DQ in the future, note that the DQ axis should not be + // changed (remains 0), but all DQ inputs should be permuted. + assert(!dq_to_look_past); - // unset updated input so reconnect_nodes doesn't change it back - if (dq_node) { - constant_dq_input = ""; - } else { - input = ""; - } + Permute1DConstant(graph, node, *constant, i, constant_to_modify, perm); + // unset updated input so reconnect_nodes doesn't change it back + input = ""; return; } @@ -1063,8 +1210,8 @@ static void TransposeInputImpl(api::GraphRef& graph, api::NodeRef& node, size_t graph.TransposeInitializer(constant_to_modify, perm); - if (dq_node) { - UpdateDQNodeInputAndShape(graph, *dq_node, constant_to_modify); + if (dq_to_look_past) { + dq_to_look_past->SetTransposedInput(graph, constant_to_modify, perm_inv); constant_dq_input = ""; // DQ input was already updated so we don't need reconnect_nodes to handle it } @@ -1076,10 +1223,14 @@ static void TransposeInputImpl(api::GraphRef& graph, api::NodeRef& node, size_t // Look past a DQ for the Transpose if (inp_node && inp_node->OpType() == "DequantizeLinear") { - dq_node = std::move(inp_node); - auto dq_input = dq_node->Inputs()[0]; - inp_node = graph.GetNodeProducingOutput(dq_input); - consumers = graph.GetValueConsumers(dq_input); + std::optional dq_quant_info = GetQuantizationInfo(graph, *inp_node); + + if (dq_quant_info && IsSupportedQuantizationMode(dq_quant_info->mode)) { + dq_to_look_past = std::make_optional(std::move(inp_node), *dq_quant_info); + std::string_view dq_input = dq_to_look_past->GetInput0(); + inp_node = graph.GetNodeProducingOutput(dq_input); + consumers = graph.GetValueConsumers(dq_input); + } } if (inp_node != nullptr && inp_node->IsOp("Transpose")) { @@ -1089,9 +1240,9 @@ static void TransposeInputImpl(api::GraphRef& graph, api::NodeRef& node, size_t if (*perm2 == perm_inv) { std::string_view pre_transpose_value = inp_node->Inputs()[0]; - if (dq_node) { - UpdateDQNodeInputAndShape(graph, *dq_node, pre_transpose_value); - node.SetInput(i, dq_node->Outputs()[0]); + if (dq_to_look_past) { + dq_to_look_past->SetTransposedInput(graph, pre_transpose_value, perm_inv); + node.SetInput(i, dq_to_look_past->GetOutput()); } else { node.SetInput(i, pre_transpose_value); } @@ -1099,14 +1250,14 @@ static void TransposeInputImpl(api::GraphRef& graph, api::NodeRef& node, size_t // Remove the Transpose node if possible // if there's a DQ node the `consumers` list still includes it so allow for that. // in that case UpdateDQNodeInputAndShape already updated the input of the DQ node so it's safe to remove it. - if (consumers->comprehensive && consumers->nodes.size() == size_t(dq_node ? 1 : 0)) { + if (consumers->comprehensive && consumers->nodes.size() == size_t(dq_to_look_past ? 1 : 0)) { graph.RemoveNode(*inp_node); } return; } - if (!dq_node) { + if (!dq_to_look_past) { // Otherwise, compose the perm and Transpose pre_transpose_value. Cost is the same and we may be able to remove // the other Transpose. const std::vector& perm_combined = ComposePerm(*perm2, perm); @@ -1130,8 +1281,8 @@ static void TransposeInputImpl(api::GraphRef& graph, api::NodeRef& node, size_t } // any DQ node special casing doesn't apply anymore, so go back to the original inp_node - if (dq_node) { - inp_node = std::move(dq_node); + if (dq_to_look_past) { + inp_node = DQToLookPast::TakeDQNode(dq_to_look_past); consumers = graph.GetValueConsumers(input); } @@ -1301,8 +1452,8 @@ static bool CanLikelyRemoveTranspose(const api::GraphRef& graph, api::NodeRef& t // - the value is a constant initializer // - the value is the output of a DQ node who's input is a constant initializer // - UnsqueezeInput/TransposeInput can look past the DQ to update the constant initializer directly -// - DQ node is currently ignored if it uses per-channel quantization -// - supporting per-channel quantization requires modifying the scales and zero point data, which can be done +// - DQ node is currently ignored if it uses blocked quantization (per-tensor and per-axis are supported). +// - supporting blocked quantization requires modifying the scales and zero point data, which can be done // if/when there's a use-case to justify the development cost. // - the input was originally connected to a shared constant initializer that was updated in place by UnsqueezeInput // or TransposeInput, and usage by this node had Squeeze/Transpose nodes inserted to counteract the effect of the @@ -1321,8 +1472,8 @@ static bool IsConstant(const api::GraphRef& graph, std::string_view value_name) // look past a DQ node if (producer_node->OpType() == "DequantizeLinear") { - std::unique_ptr dq_node = GetDQWithConstInitializerInputAndSingleConsumer(graph, value_name); - if (dq_node != nullptr) { + std::optional dq_to_look_past = GetDQWithConstInitializerInputAndSingleConsumer(graph, value_name); + if (dq_to_look_past) { // DQ node pointing to an constant initializer return true; } @@ -1363,7 +1514,7 @@ static int EstimateTransposeValueCost(const api::GraphRef& graph, std::string_vi if (dq_input_node != nullptr) { if (dq_input_node->OpType() == "Squeeze") { auto squeeze_input_node = graph.GetNodeProducingOutput(dq_input_node->Inputs()[0]); - if (squeeze_input_node->OpType() == "Transpose") { + if (squeeze_input_node != nullptr && squeeze_input_node->OpType() == "Transpose") { // we only want to set this if it is a Transpose as otherwise we're invalidating the cost given it is // rank based and the Squeeze will change that. producer_node = std::move(squeeze_input_node); @@ -2733,37 +2884,23 @@ static bool TryFixTransposeMissingDQ(OptimizerCtx& ctx, api::NodeRef& transpose_ const auto q_domain = q_node.Domain(); const auto scale_input = q_node_inputs[1]; - const auto scale_value_info = ctx.graph.GetValueInfo(scale_input); std::optional zp_input; - std::optional> zp_value_info; - - auto scale_shape = scale_value_info->Shape(); - if (!scale_shape) { - // Axis potentially needs updating due to the transpose but we don't have the required info to do it. - return false; - } if (q_node_inputs.size() > 2) { zp_input = q_node_inputs[2]; - zp_value_info = ctx.graph.GetValueInfo(zp_input.value()); } - // Q uses per-axis quantization if its scale input is not a scalar and not a tensor with shape (1,). - // Note there could be an axis value as the onnx spec says that is ignored for per-tensor quantization, - // so we have to check the scale input's shape. - const bool update_axis = !IsScalarOr1Element1DTensor(*scale_shape); - int64_t axis = q_node.GetAttributeIntDefault("axis", 1); + std::optional q_quant_info = GetQuantizationInfo(ctx.graph, q_node); + if (!q_quant_info || !IsSupportedQuantizationMode(q_quant_info->mode)) { + return false; // Can't get quantization mode/axis or is a quantization mode that is not supported. + } - if (update_axis) { - auto perm = GetPermAttrIfValid(transpose_node); - assert(perm.has_value()); // onnx shape inferencing checks that `perm` is valid - - const auto q_input0_info = ctx.graph.GetValueInfo(q_node_inputs[0]); - std::optional q_input0_rank = q_input0_info->ShapeRank(); - if (!q_input0_rank.has_value() || !NormalizeAndValidateAxis(axis, *q_input0_rank)) { - return false; // Unable to normalize the Q's axis. - } + int64_t axis = q_quant_info->norm_axis; + if (q_quant_info->mode == QuantizationMode::kPerAxis) { + // Have to update the axis for newly inserted Q/DQ before a Transpose if using per-channel quantization. + auto perm = GetPermAttrIfValid(transpose_node); + assert(perm.has_value()); // onnx shape inferencing checks that `perm` is valid axis = (*perm)[gsl::narrow_cast(axis)]; // Note: do not invert permutation. } diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h index 6d1f1f8535ba4..0095ead75f0c8 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h @@ -3,6 +3,7 @@ #pragma once +#include #include #include @@ -59,7 +60,7 @@ struct OptimizerCtx { /// {0} inline std::vector FirstInput(OptimizerCtx&, api::NodeRef&) { return {0}; } -std::vector InvertPerm(const std::vector& perm); +std::vector InvertPerm(gsl::span perm); // Transpose all inputs and all outputs bool HandleSimpleNode(HandlerArgs& args); diff --git a/onnxruntime/test/optimizer/transpose_optimizer_test.cc b/onnxruntime/test/optimizer/transpose_optimizer_test.cc index c253d73a55441..8d90e48db97c1 100644 --- a/onnxruntime/test/optimizer/transpose_optimizer_test.cc +++ b/onnxruntime/test/optimizer/transpose_optimizer_test.cc @@ -4848,20 +4848,135 @@ TEST(TransposeOptimizerTests, ConstantFoldTransposeAndSqueezeOutputCorrectness) testing::ContainerEq(fetches[1].Get().DataAsSpan())); } -// Tests the fix-up of a QDQ NodeUnit containing a per-channel DQ followed by an Unsqueeze. -// Before: DQ (axis = 0) -> Unsqueeze (axes = [0, 1, 2]) -> Op -// After: DQ (axis = 0) -> Unsqueeze (axes = [0, 1, 2]) -> Q (axis = 3) -> DQ (axis = 3) -> Op -TEST(TransposeOptimizerTests, FixQDQNodeUnitWithPerChannelDQUnsqueeze) { - // Test model contains a Mul with a broadcastable/constant/per-channel DQ input. When a transpose is pushed through - // the Mul, the contant DQ input is Unsqueezed. - auto model_uri = ORT_TSTR("testdata/transpose_optimization_unsqueeze_dq_axis.qdq.onnx"); +// Utility to get the axis attribute for a Q or DQ node. +static void GetQOrDQAxis(const Node& q_or_dq_node, /*out*/ int64_t& axis) { + const NodeAttributes& attrs = q_or_dq_node.GetAttributes(); + auto axis_attr_it = attrs.find("axis"); + axis = 1; + + if (axis_attr_it != attrs.end()) { + auto axis_attr = axis_attr_it->second; + ASSERT_TRUE(axis_attr.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_INT); + axis = axis_attr.i(); + } +} + +// Tests the fix-up of a QDQ NodeUnit containing a per-axis DQ followed by an Unsqueeze and Transpose. +// Before: DQ (axis = 0) -> Unsqueeze (axes = [0, 1, 2]) -> Transpose (perm = [0, 3, 1, 2]) -> Op +// After: DQ (axis = 0) -> Unsqueeze -> Q(axis = 3) -> DQ(axis = 3) -> Transpose -> Q(axis = 1) -> DQ(axis = 1) -> Op +TEST(TransposeOptimizerTests, FixQDQNodeUnitWithPerAxisDQUnsqueezeTranspose) { + // Model contains a Mul with a broadcastable/per-axis DQ input[1]. When a transpose is pushed through + // the Mul's input[0], input[1]'s input is unsqueezed and transposed. + auto model_uri = ORT_TSTR("testdata/transpose_optimizer_qdq_fixup_unsqueeze_per_axis_dq.onnx"); RandomValueGenerator random{123}; std::vector input_dims{1, 3, 4, 4}; std::vector input0_data = random.Gaussian(input_dims, 0.0f, 1.0f); + std::vector input1_data = {0, 1, 2}; + auto allocators = TestCPUExecutionProvider()->CreatePreferredAllocators(); OrtValue input0; - CreateMLValue(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], input_dims, input0_data, &input0); + OrtValue input1; + CreateMLValue(allocators[0], input_dims, input0_data, &input0); + CreateMLValue(allocators[0], {3}, input1_data, &input1); + + NameMLValMap feeds{{"input0", input0}, {"input1", input1}}; + + std::vector output_names{"output0"}; + std::vector fetches_orig; + std::vector fetches; + + SessionOptions so; + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "1")); + so.graph_optimization_level = TransformerLevel::Default; // off + + // get results with no modifications to the model + { + InferenceSessionWrapper session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.Load(model_uri)); + ASSERT_STATUS_OK(session.Initialize()); + ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches_orig)); + } + + { + InferenceSessionWrapper session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.Load(model_uri)); + + Graph& graph = session.GetMutableGraph(); + CPUAllocator allocator; + + namespace alias_oto = onnx_transpose_optimization; + auto api_graph = MakeApiGraph(graph, + TestCPUExecutionProvider()->CreatePreferredAllocators()[0], + /*new_node_ep*/ nullptr); + + alias_oto::OptimizeResult result = alias_oto::Optimize(*api_graph); + ASSERT_EQ(result.error_msg, std::nullopt); + ASSERT_TRUE(result.graph_modified); + ASSERT_TRUE(graph.GraphResolveNeeded()); + ASSERT_STATUS_OK(graph.Resolve()); + + // Use this hack to save model for viewing if needed + // ASSERT_STATUS_OK(Model::Save(const_cast(session.GetModel()), + // ToPathString("transpose_optimizer_qdq_fixup_unsqueeze_per_axis_dq.debug.onnx"))); + + std::map op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["Unsqueeze"], 1) << "1 Unsqueeze node added to broadcastable Mul weight."; + EXPECT_EQ(op_to_count["Transpose"], 1) << "2 Transposes at the I/O cancel. 1 Transpose inserted above Mul weight."; + + // Get the Unsqueeze and Transpose nodes. + Node* unsqueeze_node = nullptr; + Node* transpose_node = nullptr; + for (auto& node : graph.Nodes()) { + const std::string& op_type = node.OpType(); + if (op_type == "Unsqueeze") { + unsqueeze_node = &node; + } else if (op_type == "Transpose") { + transpose_node = &node; + } + } + + // DQ axis starts as 0 + ASSERT_TRUE(unsqueeze_node != nullptr); + const auto& dq_before_unsqueeze = *(unsqueeze_node->InputNodesBegin()); + int64_t dq_before_unsqueeze_axis = 1; + GetQOrDQAxis(dq_before_unsqueeze, dq_before_unsqueeze_axis); + EXPECT_EQ(dq_before_unsqueeze_axis, 0); + + // Axis changes to 3 after Unsqueeze + const auto& q_after_unsqueeze = *(unsqueeze_node->OutputNodesBegin()); + int64_t q_after_unsqueeze_axis = 1; + GetQOrDQAxis(q_after_unsqueeze, q_after_unsqueeze_axis); + EXPECT_EQ(q_after_unsqueeze_axis, 3); + + // Axis changes to 1 after Transpose + ASSERT_TRUE(transpose_node != nullptr); + const auto& q_after_transpose = *(transpose_node->OutputNodesBegin()); + int64_t q_after_transpose_axis = 1; + GetQOrDQAxis(q_after_transpose, q_after_transpose_axis); + EXPECT_EQ(q_after_transpose_axis, 1); + + ASSERT_STATUS_OK(session.Initialize()); + ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches)); + } + + ASSERT_THAT(fetches_orig[0].Get().DataAsSpan(), + testing::ContainerEq(fetches[0].Get().DataAsSpan())); +} + +// Tests the in-place unsqueeze and transpose of a constant consumed by a per-axis DQ. +TEST(TransposeOptimizerTests, InPlaceUnsqueezeTransposePerAxisDQ) { + // Model contains a Mul with a constant/broadcastable/per-axis DQ input[1]. + // When a transpose is pushed through the Mul's input[0], input[1]'s input is unsqueezed and transposed in-place. + auto model_uri = ORT_TSTR("testdata/transpose_optimizer_in_place_transpose_unsqueeze_per_axis_dq.onnx"); + + RandomValueGenerator random{123}; + std::vector input_dims{1, 3, 4, 4}; + std::vector input0_data = random.Gaussian(input_dims, 0.0f, 1.0f); + + auto allocators = TestCPUExecutionProvider()->CreatePreferredAllocators(); + OrtValue input0; + CreateMLValue(allocators[0], input_dims, input0_data, &input0); NameMLValMap feeds{{"input0", input0}}; @@ -4885,7 +5000,6 @@ TEST(TransposeOptimizerTests, FixQDQNodeUnitWithPerChannelDQUnsqueeze) { InferenceSessionWrapper session{so, GetEnvironment()}; ASSERT_STATUS_OK(session.Load(model_uri)); - // We call the ONNX transpose optimizer directly to use a custom cost check function. Graph& graph = session.GetMutableGraph(); CPUAllocator allocator; @@ -4894,22 +5008,161 @@ TEST(TransposeOptimizerTests, FixQDQNodeUnitWithPerChannelDQUnsqueeze) { TestCPUExecutionProvider()->CreatePreferredAllocators()[0], /*new_node_ep*/ nullptr); - // Use a custom optimization cost check that aggressively pushes channel-last or channel-first transposes. - auto custom_cost_fn = - [](const alias_oto::api::GraphRef& /* graph */, - const alias_oto::api::NodeRef& /* node */, - const std::vector& perm, - const std::unordered_set& /* outputs_leading_to_transpose */) -> alias_oto::CostCheckResult { - if (perm == alias_oto::ChannelFirstToLastPerm(perm.size()) || - perm == alias_oto::ChannelLastToFirstPerm(perm.size())) { - return alias_oto::CostCheckResult::kPushTranspose; + alias_oto::OptimizeResult result = alias_oto::Optimize(*api_graph); + ASSERT_EQ(result.error_msg, std::nullopt); + ASSERT_TRUE(result.graph_modified); + ASSERT_TRUE(graph.GraphResolveNeeded()); + ASSERT_STATUS_OK(graph.Resolve()); + + // Use this hack to save model for viewing if needed + // ASSERT_STATUS_OK(Model::Save(const_cast(session.GetModel()), + // ToPathString("updated_model_inplace_peraxis.onnx"))); + + std::map op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["Unsqueeze"], 0) << "per-axis DQ constant was unsqueezed in-place."; + EXPECT_EQ(op_to_count["Transpose"], 0) << "2 pre-existing Transposes at the I/O cancel. " + << "per-axis DQ constant was transposed in-place."; + + ASSERT_STATUS_OK(session.Initialize()); + ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches)); + } + + ASSERT_THAT(fetches_orig[0].Get().DataAsSpan(), + testing::ContainerEq(fetches[0].Get().DataAsSpan())); +} + +// Tests the canceling of a pre-existing Transpose before a per-axis DQ during call to TransposeInputImpl. +// Before: input1 -> Transpose(perm = [0, 2, 3, 1]) -> DQ (axis = -1) -> Mul +// After : input1 -> DQ (axis = 1) -> Mul +TEST(TransposeOptimizerTests, CancelTransposeBeforePerAxisDQ) { + auto model_uri = ORT_TSTR("testdata/transpose_optimizer_cancel_transpose_per_axis_dq.onnx"); + + RandomValueGenerator random{123}; + std::vector input_dims{1, 3, 4, 4}; + std::vector input0_data = random.Gaussian(input_dims, 0.0f, 1.0f); + std::vector input1_data = {0, 1, 2}; + + auto allocators = TestCPUExecutionProvider()->CreatePreferredAllocators(); + OrtValue input0; + OrtValue input1; + CreateMLValue(allocators[0], input_dims, input0_data, &input0); + CreateMLValue(allocators[0], {1, 3, 1, 1}, input1_data, &input1); + + NameMLValMap feeds{{"input0", input0}, {"input1", input1}}; + + std::vector output_names{"output0"}; + std::vector fetches_orig; + std::vector fetches; + + SessionOptions so; + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "1")); + so.graph_optimization_level = TransformerLevel::Default; // off + + // get results with no modifications to the model + { + InferenceSessionWrapper session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.Load(model_uri)); + ASSERT_STATUS_OK(session.Initialize()); + ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches_orig)); + } + + { + InferenceSessionWrapper session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.Load(model_uri)); + + Graph& graph = session.GetMutableGraph(); + CPUAllocator allocator; + + namespace alias_oto = onnx_transpose_optimization; + auto api_graph = MakeApiGraph(graph, + TestCPUExecutionProvider()->CreatePreferredAllocators()[0], + /*new_node_ep*/ nullptr); + + alias_oto::OptimizeResult result = alias_oto::Optimize(*api_graph); + ASSERT_EQ(result.error_msg, std::nullopt); + ASSERT_TRUE(result.graph_modified); + ASSERT_TRUE(graph.GraphResolveNeeded()); + ASSERT_STATUS_OK(graph.Resolve()); + + // Use this hack to save model for viewing if needed + // ASSERT_STATUS_OK(Model::Save(const_cast(session.GetModel()), + // ToPathString("updated_model_peraxis_transpose_cancel.onnx"))); + + std::map op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["Transpose"], 0) << "2 Transposes at the I/O cancel. " + << "Transpose inserted above Mul weight cancels."; + + // Get the DQ above Mul's input[1] + Node* dq_node = nullptr; + for (auto& node : graph.Nodes()) { + if (node.OpType() == "DequantizeLinear" && node.Name() == "dq_mul_input_1") { + dq_node = &node; + break; } + } - return alias_oto::CostCheckResult::kFallThrough; - }; + // DQ axis changed from -1 (3) to 1 due to Tranpose above DQ being canceled. + ASSERT_TRUE(dq_node != nullptr); + int64_t dq_axis = 1; + GetQOrDQAxis(*dq_node, dq_axis); + EXPECT_EQ(dq_axis, 1); - alias_oto::OptimizeResult result = alias_oto::Optimize(*api_graph, /*provider_type*/ "", custom_cost_fn); + ASSERT_STATUS_OK(session.Initialize()); + ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches)); + } + + ASSERT_THAT(fetches_orig[0].Get().DataAsSpan(), + testing::ContainerEq(fetches[0].Get().DataAsSpan())); +} + +// Tests the canceling of a pre-existing Squeeze before a per-axis DQ during call to UnsqueezeInput. +// Before: input1 (shape = [1, 1, 1, 3]) -> Squeeze(axes = [0, 1, 2]) -> DQ (axis = 0) -> Mul +// After : input1 -> DQ (axis = 3) -> Mul +TEST(TransposeOptimizerTests, CancelSqueezeBeforePerAxisDQ) { + auto model_uri = ORT_TSTR("testdata/transpose_optimizer_cancel_squeeze_per_axis_dq.onnx"); + + RandomValueGenerator random{123}; + std::vector input_dims{1, 3, 4, 4}; + std::vector input0_data = random.Gaussian(input_dims, 0.0f, 1.0f); + std::vector input1_data = {0, 1, 2}; + + auto allocators = TestCPUExecutionProvider()->CreatePreferredAllocators(); + OrtValue input0; + OrtValue input1; + CreateMLValue(allocators[0], input_dims, input0_data, &input0); + CreateMLValue(allocators[0], {1, 1, 1, 3}, input1_data, &input1); + + NameMLValMap feeds{{"input0", input0}, {"input1", input1}}; + + std::vector output_names{"output0"}; + std::vector fetches_orig; + std::vector fetches; + + SessionOptions so; + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "1")); + so.graph_optimization_level = TransformerLevel::Default; // off + + // get results with no modifications to the model + { + InferenceSessionWrapper session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.Load(model_uri)); + ASSERT_STATUS_OK(session.Initialize()); + ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches_orig)); + } + + { + InferenceSessionWrapper session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.Load(model_uri)); + + Graph& graph = session.GetMutableGraph(); + CPUAllocator allocator; + + namespace alias_oto = onnx_transpose_optimization; + auto api_graph = MakeApiGraph(graph, + TestCPUExecutionProvider()->CreatePreferredAllocators()[0], + /*new_node_ep*/ nullptr); + alias_oto::OptimizeResult result = alias_oto::Optimize(*api_graph); ASSERT_EQ(result.error_msg, std::nullopt); ASSERT_TRUE(result.graph_modified); ASSERT_TRUE(graph.GraphResolveNeeded()); @@ -4917,11 +5170,26 @@ TEST(TransposeOptimizerTests, FixQDQNodeUnitWithPerChannelDQUnsqueeze) { // Use this hack to save model for viewing if needed // ASSERT_STATUS_OK(Model::Save(const_cast(session.GetModel()), - // ToPathString("transpose_optimization_unsqueeze_dq_axis.qdq.updated.onnx"))); + // ToPathString("updated_model_peraxis_squeeze_cancel.onnx"))); std::map op_to_count = CountOpsInGraph(graph); - EXPECT_EQ(op_to_count["Unsqueeze"], 1) << "1 Unsqueeze node added to broadcastable Mul weight."; - EXPECT_EQ(op_to_count["Transpose"], 1) << "2 Transposes at the I/O cancel. 1 Transpose inserted above Mul weight."; + EXPECT_EQ(op_to_count["Squeeze"], 0) << "Canceled by unsqueezed input consumed by per-axis DQ"; + EXPECT_EQ(op_to_count["Unsqueeze"], 0) << "No Unsqueeze inserted because it cancels with pre-existing Squeeze."; + + // Get the DQ above Mul's input[1] + Node* dq_node = nullptr; + for (auto& node : graph.Nodes()) { + if (node.OpType() == "DequantizeLinear" && node.Name() == "dq_mul_input_1") { + dq_node = &node; + break; + } + } + + // DQ axis changed from 0 to 3 due to Squeeze above DQ being canceled. + ASSERT_TRUE(dq_node != nullptr); + int64_t dq_axis = 1; + GetQOrDQAxis(*dq_node, dq_axis); + EXPECT_EQ(dq_axis, 3); ASSERT_STATUS_OK(session.Initialize()); ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches)); diff --git a/onnxruntime/test/testdata/make_transpose_optimization_unsqueeze_dq_axis.py b/onnxruntime/test/testdata/make_transpose_optimization_unsqueeze_dq_axis.py deleted file mode 100644 index 6d542e52b8bda..0000000000000 --- a/onnxruntime/test/testdata/make_transpose_optimization_unsqueeze_dq_axis.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import numpy as np -import onnx - -if __name__ == "__main__": - """ - Creates a QDQ model with a per-channel DQ weight that is Unsqueezed and Transposed by the Transpose optimizer. - """ - input0_shape = (1, 3, 4, 4) - - input0 = onnx.helper.make_tensor_value_info("input0", onnx.TensorProto.FLOAT, input0_shape) - output0 = onnx.helper.make_tensor_value_info("output0", onnx.TensorProto.FLOAT, None) - - scale_1 = onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "scale_1") - zp_128 = onnx.numpy_helper.from_array(np.array(128, dtype=np.uint8), "zp_128") - scale_inv_255 = onnx.numpy_helper.from_array(np.array(1.0 / 255.0, dtype=np.float32), "scale_inv_255") - zp_0 = onnx.numpy_helper.from_array(np.array(0, dtype=np.uint8), "zp_0") - - mul_weight_i8_data = np.array([1, 2, 3], dtype=np.int8) - mul_weight_scales_data = np.array([1.0, 1.0, 1.0], dtype=np.float32) - mul_weight_zps_data = np.array([0, 0, 0], dtype=np.int8) - mul_weight = onnx.numpy_helper.from_array(mul_weight_i8_data, "mul_weight") - mul_weight_scales = onnx.numpy_helper.from_array(mul_weight_scales_data, "mul_weight_scales") - mul_weight_zps = onnx.numpy_helper.from_array(mul_weight_zps_data, "mul_weight_zps") - - # Transpose to channel-last - tp0_node = onnx.helper.make_node("Transpose", ["input0"], ["tp0_out"], name="tp0_node", perm=(0, 2, 3, 1)) - - # Q_0 - q0_node = onnx.helper.make_node("QuantizeLinear", ["tp0_out", "scale_1", "zp_128"], ["q0_out"], name="q0_node") - - # DQ_0 - dq0_node = onnx.helper.make_node("DequantizeLinear", ["q0_out", "scale_1", "zp_128"], ["dq0_out"], name="dq0_node") - - # Sigmoid - sigmoid_node = onnx.helper.make_node("Sigmoid", ["dq0_out"], ["sigmoid_out"], name="sigmoid_node") - - # Q_1 - q1_node = onnx.helper.make_node( - "QuantizeLinear", ["sigmoid_out", "scale_inv_255", "zp_0"], ["q1_out"], name="q1_node" - ) - - # DQ_1 - dq1_node = onnx.helper.make_node( - "DequantizeLinear", ["q1_out", "scale_inv_255", "zp_0"], ["dq1_out"], name="dq1_node" - ) - - # DQ_weight - dq_weight_node = onnx.helper.make_node( - "DequantizeLinear", - ["mul_weight", "mul_weight_scales", "mul_weight_zps"], - ["dq_weight_out"], - name="dq_weight_node", - axis=0, - ) - - # Mul - mul_node = onnx.helper.make_node("Mul", ["dq1_out", "dq_weight_out"], ["mul_out"], name="mul_node") - - # Q_2 - q2_node = onnx.helper.make_node("QuantizeLinear", ["mul_out", "scale_inv_255", "zp_0"], ["q2_out"], name="q2_node") - - # DQ_2 - dq2_node = onnx.helper.make_node( - "DequantizeLinear", ["q2_out", "scale_inv_255", "zp_0"], ["dq2_out"], name="dq2_node" - ) - - # Transpose to channel-first - tp1_node = onnx.helper.make_node("Transpose", ["dq2_out"], ["output0"], name="tp1_node", perm=(0, 3, 1, 2)) - - graph = onnx.helper.make_graph( - [ - tp0_node, - q0_node, - dq0_node, - sigmoid_node, - q1_node, - dq1_node, - dq_weight_node, - mul_node, - q2_node, - dq2_node, - tp1_node, - ], - "transpose_opt_unsqueeze_dq_axis", - [input0], - [output0], - initializer=[scale_1, zp_128, scale_inv_255, zp_0, mul_weight, mul_weight_scales, mul_weight_zps], - ) - opset_imports = [ - onnx.helper.make_opsetid("", 19), - ] - qdq_model = onnx.helper.make_model(graph, opset_imports=opset_imports) - - print("[INFO]: Running onnx.checker on qdq model") - qdq_model = onnx.shape_inference.infer_shapes(qdq_model) - onnx.checker.check_model(qdq_model, True) - qdq_model_path = "transpose_optimization_unsqueeze_dq_axis.qdq.onnx" - - print(f"[INFO]: Saving {qdq_model_path}") - onnx.save_model(qdq_model, qdq_model_path) diff --git a/onnxruntime/test/testdata/make_transpose_optimizer_per_axis_qdq_models.py b/onnxruntime/test/testdata/make_transpose_optimizer_per_axis_qdq_models.py new file mode 100644 index 0000000000000..b446b60686043 --- /dev/null +++ b/onnxruntime/test/testdata/make_transpose_optimizer_per_axis_qdq_models.py @@ -0,0 +1,207 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import numpy as np +import onnx + + +def subgraph_1d_const_input_dq(inputs, initializers, nodes) -> str: + """ + Creates mul_weight -> DQ. mul_weight is a constant of rank 1. + """ + mul_weight_i8_data = np.array([1, 2, 3], dtype=np.int8) + mul_weight = onnx.numpy_helper.from_array(mul_weight_i8_data, "mul_weight") + initializers.append(mul_weight) + + dq_output_name = "mul_input_1" + nodes.append( + onnx.helper.make_node( + "DequantizeLinear", + ["mul_weight", "mul_weight_scales", "mul_weight_zps"], + [dq_output_name], + name="dq_mul_input_1", + axis=0, + ) + ) + + return dq_output_name + + +def subgraph_1d_input_dq(inputs, initializers, nodes) -> str: + """ + Creates input1 -> DQ. input1 is a graph input of rank 1. + """ + input1_shape = (3,) + inputs.append(onnx.helper.make_tensor_value_info("input1", onnx.TensorProto.INT8, input1_shape)) + + dq_output_name = "mul_input_1" + nodes.append( + onnx.helper.make_node( + "DequantizeLinear", + ["input1", "mul_weight_scales", "mul_weight_zps"], + [dq_output_name], + name="dq_mul_input_1", + axis=0, + ) + ) + + return dq_output_name + + +def subgraph_4d_input_squeeze_dq(inputs, initializers, nodes) -> str: + """ + Creates input1 -> Squeeze -> DQ. input1 is a graph input of rank 4. + """ + input1_shape = (1, 1, 1, 3) + inputs.append(onnx.helper.make_tensor_value_info("input1", onnx.TensorProto.INT8, input1_shape)) + + axes_data = np.array([0, 1, 2], dtype=np.int64) + initializers.append(onnx.numpy_helper.from_array(axes_data, "axes_const")) + nodes.append(onnx.helper.make_node("Squeeze", ["input1", "axes_const"], ["squeeze_out"], name="squeeze_node")) + + dq_output_name = "mul_input_1" + nodes.append( + onnx.helper.make_node( + "DequantizeLinear", + ["squeeze_out", "mul_weight_scales", "mul_weight_zps"], + [dq_output_name], + name="dq_mul_input_1", + axis=0, + ) + ) + + return dq_output_name + + +def subgraph_4d_input_transpose_dq(inputs, initializers, nodes) -> str: + """ + Creates input1 -> Transpose -> DQ. input1 is a graph input of rank 4. + """ + input1_shape = (1, 3, 1, 1) + inputs.append(onnx.helper.make_tensor_value_info("input1", onnx.TensorProto.INT8, input1_shape)) + + perm = [0, 2, 3, 1] # To channel-last + nodes.append(onnx.helper.make_node("Transpose", ["input1"], ["tp_out_"], perm=perm, name="transpose_")) + + dq_output_name = "mul_input_1" + nodes.append( + onnx.helper.make_node( + "DequantizeLinear", + ["tp_out_", "mul_weight_scales", "mul_weight_zps"], + [dq_output_name], + name="dq_mul_input_1", + axis=-1, + ) + ) + + return dq_output_name + + +def make_model(model_path: str, build_mul_input_1_subgraph): + """ + Creates a QDQ model with a per-axis DQ input that is Unsqueezed and Transposed by the Transpose optimizer. + """ + input0_shape = (1, 3, 4, 4) + + inputs = [onnx.helper.make_tensor_value_info("input0", onnx.TensorProto.FLOAT, input0_shape)] + outputs = [onnx.helper.make_tensor_value_info("output0", onnx.TensorProto.FLOAT, None)] + + mul_weight_scales_data = np.array([1.0, 1.0, 1.0], dtype=np.float32) + mul_weight_zps_data = np.array([0, 0, 0], dtype=np.int8) + + initializers = [ + onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "scale_1"), + onnx.numpy_helper.from_array(np.array(128, dtype=np.uint8), "zp_128"), + onnx.numpy_helper.from_array(np.array(1.0 / 255.0, dtype=np.float32), "scale_inv_255"), + onnx.numpy_helper.from_array(np.array(0, dtype=np.uint8), "zp_0"), + onnx.numpy_helper.from_array(mul_weight_scales_data, "mul_weight_scales"), + onnx.numpy_helper.from_array(mul_weight_zps_data, "mul_weight_zps"), + ] + nodes = [] + + # Transpose to channel-last + tp0_node = onnx.helper.make_node("Transpose", ["input0"], ["tp0_out"], name="tp0_node", perm=(0, 2, 3, 1)) + nodes.append(tp0_node) + + # Q_0 + q0_node = onnx.helper.make_node("QuantizeLinear", ["tp0_out", "scale_1", "zp_128"], ["q0_out"], name="q0_node") + nodes.append(q0_node) + + # DQ_0 + dq0_node = onnx.helper.make_node("DequantizeLinear", ["q0_out", "scale_1", "zp_128"], ["dq0_out"], name="dq0_node") + nodes.append(dq0_node) + + # Sigmoid + sigmoid_node = onnx.helper.make_node("Sigmoid", ["dq0_out"], ["sigmoid_out"], name="sigmoid_node") + nodes.append(sigmoid_node) + + # Q_1 + q1_node = onnx.helper.make_node( + "QuantizeLinear", ["sigmoid_out", "scale_inv_255", "zp_0"], ["q1_out"], name="q1_node" + ) + nodes.append(q1_node) + + # DQ_1 + dq1_node = onnx.helper.make_node( + "DequantizeLinear", ["q1_out", "scale_inv_255", "zp_0"], ["dq1_out"], name="dq1_node" + ) + nodes.append(dq1_node) + + # DQ for mul input[1] + mul_input_1_name = build_mul_input_1_subgraph(inputs, initializers, nodes) + + # Mul + mul_node = onnx.helper.make_node("Mul", ["dq1_out", mul_input_1_name], ["mul_out"], name="mul_node") + nodes.append(mul_node) + + # Q_2 + q2_node = onnx.helper.make_node("QuantizeLinear", ["mul_out", "scale_inv_255", "zp_0"], ["q2_out"], name="q2_node") + nodes.append(q2_node) + + # DQ_2 + dq2_node = onnx.helper.make_node( + "DequantizeLinear", ["q2_out", "scale_inv_255", "zp_0"], ["dq2_out"], name="dq2_node" + ) + nodes.append(dq2_node) + + # Transpose to channel-first + tp1_node = onnx.helper.make_node("Transpose", ["dq2_out"], ["output0"], name="tp1_node", perm=(0, 3, 1, 2)) + nodes.append(tp1_node) + + graph = onnx.helper.make_graph( + nodes, + "transpose_opt_unsqueeze_dq_axis", + inputs, + outputs, + initializer=initializers, + ) + opset_imports = [ + onnx.helper.make_opsetid("", 19), + ] + qdq_model = onnx.helper.make_model(graph, opset_imports=opset_imports) + + print("[INFO]: Running onnx.checker on qdq model") + qdq_model = onnx.shape_inference.infer_shapes(qdq_model) + onnx.checker.check_model(qdq_model, True) + + print(f"[INFO]: Saving {model_path}") + onnx.save_model(qdq_model, model_path) + + +if __name__ == "__main__": + make_model( + "transpose_optimizer_qdq_fixup_unsqueeze_per_axis_dq.onnx", + subgraph_1d_input_dq, + ) + make_model( + "transpose_optimizer_in_place_transpose_unsqueeze_per_axis_dq.onnx", + subgraph_1d_const_input_dq, + ) + make_model( + "transpose_optimizer_cancel_squeeze_per_axis_dq.onnx", + subgraph_4d_input_squeeze_dq, + ) + make_model( + "transpose_optimizer_cancel_transpose_per_axis_dq.onnx", + subgraph_4d_input_transpose_dq, + ) diff --git a/onnxruntime/test/testdata/transpose_optimizer_cancel_squeeze_per_axis_dq.onnx b/onnxruntime/test/testdata/transpose_optimizer_cancel_squeeze_per_axis_dq.onnx new file mode 100644 index 0000000000000..380ce7fc7e703 Binary files /dev/null and b/onnxruntime/test/testdata/transpose_optimizer_cancel_squeeze_per_axis_dq.onnx differ diff --git a/onnxruntime/test/testdata/transpose_optimizer_cancel_transpose_per_axis_dq.onnx b/onnxruntime/test/testdata/transpose_optimizer_cancel_transpose_per_axis_dq.onnx new file mode 100644 index 0000000000000..7425c1be8a9ce Binary files /dev/null and b/onnxruntime/test/testdata/transpose_optimizer_cancel_transpose_per_axis_dq.onnx differ diff --git a/onnxruntime/test/testdata/transpose_optimization_unsqueeze_dq_axis.qdq.onnx b/onnxruntime/test/testdata/transpose_optimizer_in_place_transpose_unsqueeze_per_axis_dq.onnx similarity index 78% rename from onnxruntime/test/testdata/transpose_optimization_unsqueeze_dq_axis.qdq.onnx rename to onnxruntime/test/testdata/transpose_optimizer_in_place_transpose_unsqueeze_per_axis_dq.onnx index 4d49a86948b41..202905b660973 100644 Binary files a/onnxruntime/test/testdata/transpose_optimization_unsqueeze_dq_axis.qdq.onnx and b/onnxruntime/test/testdata/transpose_optimizer_in_place_transpose_unsqueeze_per_axis_dq.onnx differ diff --git a/onnxruntime/test/testdata/transpose_optimizer_qdq_fixup_unsqueeze_per_axis_dq.onnx b/onnxruntime/test/testdata/transpose_optimizer_qdq_fixup_unsqueeze_per_axis_dq.onnx new file mode 100644 index 0000000000000..5538681ecddc7 Binary files /dev/null and b/onnxruntime/test/testdata/transpose_optimizer_qdq_fixup_unsqueeze_per_axis_dq.onnx differ