
Add ability for transpose optimizer to look past a DQ node if it has a constant initializer as input. This allows UnsqueezeInput/TransposeInput to modify the initializer in-place in the same way it would for a non-QDQ format model.

Shared initializers are also handled: any additional Squeeze/Transpose added to the other usages of the initializer should cancel out when we push the same Transpose through them.

The in-place modification means we don't need to run QDQ fixup and constant folding after layout transformation, so those optimizers do not need to be enabled in a minimal build to get an optimal model post-layout transformation.
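To see why it is safe to look past a per-tensor DQ and transpose the quantized initializer directly: DequantizeLinear with a scalar scale/zero-point commutes with Transpose. A minimal NumPy sketch of that identity (the shapes, scale, and perm here are illustrative, not taken from the commit):

```python
import numpy as np

# per-tensor DequantizeLinear: dq(x) = (x - zero_point) * scale
rng = np.random.default_rng(0)
x = rng.integers(0, 256, size=(1, 3, 2, 2), dtype=np.uint8)
scale, zero_point = 1.5, 5
perm = (0, 2, 3, 1)  # e.g. NCHW -> NHWC

dq_then_transpose = np.transpose((x.astype(np.float32) - zero_point) * scale, perm)
transpose_then_dq = (np.transpose(x, perm).astype(np.float32) - zero_point) * scale
assert np.allclose(dq_then_transpose, transpose_then_dq)
```

Because the two orderings agree, transposing the uint8 initializer in place and leaving the DQ untouched produces the same dequantized values.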
skottmckay committed Sep 20, 2023
1 parent f297d4d commit d72f284
Showing 9 changed files with 449 additions and 55 deletions.

Large diffs are not rendered by default.

onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h
@@ -48,6 +48,12 @@ struct OptimizerCtx {
// Handlers for ops that are not in the ONNX opset, or for ONNX ops where special handling is required.
// If a handler is not found in this map, the default handlers will be used.
const HandlerMap& extended_handlers;

// DQ nodes which had a shared constant initializer as input, where we updated the initializer in-place and
// inserted a Squeeze and/or Transpose on the other usages. Nodes in this set had the Squeeze/Transpose inserted.
// If we attempt to push a Transpose through them we need to look past the DQ node to try to cancel
// out the Squeeze/Transpose.
std::unordered_set<int64_t> special_cased_dq_nodes;

[cpplint warning @ onnx_transpose_optimization.h:56] Add #include <unordered_set> for unordered_set<>  [build/include_what_you_use] [4]
};

/// <summary>
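The cancellation described in the comment on special_cased_dq_nodes comes down to permutation inverses: the first usage consumes the in-place-transposed initializer, while every other usage gets a compensating Transpose with the inverse perm, which a later push of the original Transpose can cancel. A small NumPy sketch of the underlying identity (illustrative values only):

```python
import numpy as np

rng = np.random.default_rng(0)
weight = rng.standard_normal((1, 3, 2, 2)).astype(np.float32)
perm = [0, 2, 3, 1]
inv_perm = list(np.argsort(perm))  # inverse permutation

weight_t = np.transpose(weight, perm)        # in-place update seen by the first usage
restored = np.transpose(weight_t, inv_perm)  # compensating Transpose on the other usages
assert np.array_equal(restored, weight)      # so pushing the same perm through cancels out
```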
onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h
@@ -242,6 +242,12 @@ class NodeRef {
/// <returns>since version or default value -1</returns>
virtual int SinceVersion() const = 0;

/// <summary>
/// Get the unique id of the node.
/// </summary>
/// <returns>Id</returns>
virtual int64_t Id() const = 0;

virtual ~NodeRef(){};
};

onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc
@@ -95,7 +95,8 @@ class ApiNode final : public api::NodeRef {
void ClearAttribute(std::string_view name) override;
void SetInput(size_t i, std::string_view name) override;
std::string_view GetExecutionProviderType() const override;
-  virtual int SinceVersion() const override;
+  int SinceVersion() const override;
+  int64_t Id() const override;

private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ApiNode);
@@ -417,6 +418,10 @@ int ApiNode::SinceVersion() const {
return node_.SinceVersion();
}

int64_t ApiNode::Id() const {
return node_.Index();
}

// </ApiNode>

std::optional<int64_t> ApiGraph::Opset(std::string_view domain) const {
28 changes: 16 additions & 12 deletions onnxruntime/core/session/inference_session.cc
@@ -1005,18 +1005,22 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool
layout_transformation::TransformLayoutForEP(graph_to_transform, modified, execution_provider,
std::move(cpu_allocator), debug_graph_fn));

-  if (modified) {
-    ORT_RETURN_IF_ERROR_SESSIONID_(
-        graph_transformer_mgr_.ApplyTransformers(graph_to_transform, TransformerLevel::Level1, *session_logger_));
-
-    // debug the graph after the L1 transformers have run against any layout transformation changes.
-    // this is prior to GraphPartitioner::GetCapabilityForEP calling IExecutionProvider::GetCapability the second
-    // time to validate the EP that requested the layout transformation can take all nodes using the new layout.
-    // if that fails, this allows debugging the graph used in that GetCapability call.
-    if (debug_graph_fn) {
-      debug_graph_fn(graph_to_transform);
-    }
-  }
+  // Previously we ran the L1 transformers to handle constant folding of any initializers that were transposed in
+  // a QDQ format model. The transpose optimizer can now look past DQ nodes to directly update initializers, which
+  // takes care of most models without needing this.
+  //
+  // if (modified) {
+  //   ORT_RETURN_IF_ERROR_SESSIONID_(
+  //       graph_transformer_mgr_.ApplyTransformers(graph_to_transform, TransformerLevel::Level1, *session_logger_));
+  //
+  // debug the graph after the L1 transformers have run against any layout transformation changes.
+  // this is prior to GraphPartitioner::GetCapabilityForEP calling IExecutionProvider::GetCapability the second
+  // time to validate the EP that requested the layout transformation can take all nodes using the new layout.
+  // if that fails, this allows debugging the graph used in that GetCapability call.
+  // if (debug_graph_fn) {
+  //   debug_graph_fn(graph_to_transform);
+  // }
+  // }

return Status::OK();
};
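One way to sanity-check that no follow-up QDQ fixup or constant folding is needed is to dump the optimized graph from a session and look for leftover Transpose/Squeeze nodes. A hedged sketch using the public onnxruntime Python API; "model.onnx" and "optimized.onnx" are placeholder paths:

```python
import onnx
import onnxruntime as ort

so = ort.SessionOptions()
so.optimized_model_filepath = "optimized.onnx"  # serialize the post-optimization graph
ort.InferenceSession("model.onnx", so, providers=["CPUExecutionProvider"])

m = onnx.load("optimized.onnx")
leftovers = [n.name for n in m.graph.node if n.op_type in ("Transpose", "Squeeze")]
print("leftover Transpose/Squeeze nodes:", leftovers)  # ideally empty for a QDQ model
```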
120 changes: 106 additions & 14 deletions onnxruntime/test/optimizer/transpose_optimizer_test.cc
@@ -12,13 +12,17 @@
#include "core/graph/node_attr_utils.h"
#include "core/framework/op_node_proto_helper.h"
#include "core/framework/utils.h"
#include "core/optimizer/transpose_optimization/onnx_transpose_optimization.h"
#include "core/optimizer/transpose_optimization/optimizer_api.h"
#include "core/optimizer/transpose_optimization/ort_optimizer_utils.h"
#include "core/session/onnxruntime_session_options_config_keys.h"

#include "test/test_environment.h"
#include "test/optimizer/graph_transform_test_builder.h"
#include "test/providers/internal_testing/internal_testing_execution_provider.h"
#include "test/util/include/asserts.h"
#include "test/util/include/inference_session_wrapper.h"
#include "test/util/include/test_utils.h"

namespace onnxruntime {
namespace test {
@@ -4395,9 +4399,9 @@ TEST(TransposeOptimizerTests, RegressionTest_GitHubIssue9671) {

SessionOptions so;
so.session_logid = "TransposeOptimizerTests.RegressionTest_GitHubIssue9671";
-  InferenceSession session_object{so, GetEnvironment()};
-  ASSERT_STATUS_OK(session_object.Load(model_uri));
-  ASSERT_STATUS_OK(session_object.Initialize());  // optimizers run during initialization
+  InferenceSession session{so, GetEnvironment()};
+  ASSERT_STATUS_OK(session.Load(model_uri));
+  ASSERT_STATUS_OK(session.Initialize());  // optimizers run during initialization
}

// regression test for a model where the transpose optimizations incorrectly removed a node providing an implicit
@@ -4409,9 +4413,9 @@ TEST(TransposeOptimizerTests, RegressionTest_GitHubIssue10305) {

SessionOptions so;
so.session_logid = "TransposeOptimizerTests.RegressionTest_GitHubIssue10305";
-  InferenceSession session_object{so, GetEnvironment()};
-  ASSERT_STATUS_OK(session_object.Load(model_uri));
-  ASSERT_STATUS_OK(session_object.Initialize());  // optimizers run during initialization
+  InferenceSession session{so, GetEnvironment()};
+  ASSERT_STATUS_OK(session.Load(model_uri));
+  ASSERT_STATUS_OK(session.Initialize());  // optimizers run during initialization
}

// regression test for a model with DQ node with per-axis dequantization followed by a Transpose.
@@ -4432,18 +4436,18 @@ TEST(TransposeOptimizerTests, RegressionTest_GitHubIssue12151) {

{
so.graph_optimization_level = TransformerLevel::Default; // off
-    InferenceSession session_object{so, GetEnvironment()};
-    ASSERT_STATUS_OK(session_object.Load(model_uri));
-    ASSERT_STATUS_OK(session_object.Initialize());
-    ASSERT_STATUS_OK(session_object.Run(feeds, output_names, &fetches_orig));
+    InferenceSession session{so, GetEnvironment()};
+    ASSERT_STATUS_OK(session.Load(model_uri));
+    ASSERT_STATUS_OK(session.Initialize());
+    ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches_orig));
}

{
so.graph_optimization_level = TransformerLevel::Level1; // enable transpose optimizer
-    InferenceSession session_object{so, GetEnvironment()};
-    ASSERT_STATUS_OK(session_object.Load(model_uri));
-    ASSERT_STATUS_OK(session_object.Initialize());
-    ASSERT_STATUS_OK(session_object.Run(feeds, output_names, &fetches));
+    InferenceSession session{so, GetEnvironment()};
+    ASSERT_STATUS_OK(session.Load(model_uri));
+    ASSERT_STATUS_OK(session.Initialize());
+    ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches));
}

ASSERT_THAT(fetches_orig[0].Get<Tensor>().DataAsSpan<float>(),
@@ -4543,5 +4547,93 @@ TEST(TransposeOptimizerTests, QnnTransposeReshapeQDQ) {
}
#endif
}

using namespace onnx_transpose_optimization;

[cpplint warning @ transpose_optimizer_test.cc:4551] Do not use namespace using-directives. Use using-declarations instead.  [build/namespaces] [5]
static CostCheckResult AlwaysPushTranspose(const api::GraphRef& /*graph*/,
const api::NodeRef& /*node*/,
const std::vector<int64_t>& /*perm*/,
const std::unordered_set<std::string>& /*outputs_leading_to_transpose*/) {
return onnx_transpose_optimization::CostCheckResult::kPushTranspose;
}

static void CheckSharedInitializerHandling(bool broadcast) {
auto model_uri = broadcast ? ORT_TSTR("testdata/transpose_optimizer_shared_initializers_broadcast.onnx")
: ORT_TSTR("testdata/transpose_optimizer_shared_initializers.onnx");

RandomValueGenerator random{123};
std::vector<int64_t> input_dims{1, 2, 2, 3};
std::vector<float> input_data = random.Gaussian<float>(input_dims, 0.0f, 1.0f);

OrtValue input;
CreateMLValue<float>(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], input_dims, input_data, &input);

NameMLValMap feeds{{"input0", input}};

std::vector<std::string> output_names{"output0"};
std::vector<OrtValue> fetches_orig;
std::vector<OrtValue> fetches;

SessionOptions so;
ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kDebugLayoutTransformation, "1"));
ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "1"));

// get results with no modifications to the model
{
so.graph_optimization_level = TransformerLevel::Default; // off
InferenceSessionWrapper session{so, GetEnvironment()};
ASSERT_STATUS_OK(session.Load(model_uri));
ASSERT_STATUS_OK(session.Initialize());
ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches_orig));
}

{
InferenceSessionWrapper session{so, GetEnvironment()};
ASSERT_STATUS_OK(session.Load(model_uri));

// we call the ONNX transpose optimizer directly as we want to plug in the AlwaysPushTranspose cost check.
// this is to simplify the model required to exercise the shared initializer handling.
// it also means we don't need to disable optimizers that might alter the graph before the transpose optimizer
// runs (ConstantFolding, CommonSubexpressionElimination, ConstantSharing)
Graph& graph = session.GetMutableGraph();
CPUAllocator allocator;

auto api_graph = MakeApiGraph(graph, TestCPUExecutionProvider()->CreatePreferredAllocators()[0],
/*new_node_ep*/ nullptr);

OptimizeResult result = Optimize(*api_graph, "", AlwaysPushTranspose);

ASSERT_EQ(result.error_msg, std::nullopt);
ASSERT_TRUE(result.graph_modified);
ASSERT_TRUE(graph.GraphResolveNeeded());

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
EXPECT_EQ(op_to_count["Transpose"], 0) << "The Transpose nodes should have been pushed through and canceled out.";

ASSERT_STATUS_OK(graph.Resolve());

ASSERT_STATUS_OK(session.Initialize());
ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches));
}

ASSERT_THAT(fetches_orig[0].Get<Tensor>().DataAsSpan<float>(),
testing::ContainerEq(fetches[0].Get<Tensor>().DataAsSpan<float>()));
}

// test we re-use a modified shared initializer wherever possible. model has one initializer that is used by 3 DQ nodes
// and one initializer that is used by 2 Add nodes. both cases should be handled with the initializer being
// modified in-place for the first usage, and the Transpose added to the second usage being canceled out when the
// original Transpose at the start of the model is pushed down.
TEST(TransposeOptimizerTests, SharedInitializerHandling) {
CheckSharedInitializerHandling(/*broadcast*/ false);
}

// same setup as the above test, however the initializer is broadcast to bring UnsqueezeInput into play.
// the in-place modification of the initializer for the first usage results in
// <initializer> -> Transpose -> Squeeze -> {DQ | Add}
// the later usages of the initializer should attempt to cancel out the Squeeze in UnsqueezeInput,
// followed by canceling out the Transpose in TransposeInput (see the NumPy sketch after this test).
TEST(TransposeOptimizerTests, SharedInitializerHandlingBroadcast) {
CheckSharedInitializerHandling(/*broadcast*/ true);
}
} // namespace test
} // namespace onnxruntime
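To make the broadcast case concrete: a rank-2 initializer feeding a rank-4 Add has to be unsqueezed to rank 4 before the layout perm can be applied to it, which is why UnsqueezeInput runs before TransposeInput. A NumPy sketch of the equivalence the optimizer relies on (shapes mirror the test model; the sketch itself is illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((1, 3, 2, 2)).astype(np.float32)  # NCHW activation
bias = rng.standard_normal((2, 2)).astype(np.float32)     # broadcast initializer
perm = (0, 2, 3, 1)                                       # NCHW -> NHWC

bias4 = np.expand_dims(bias, axis=(0, 1))  # Unsqueeze to rank 4: shape (1, 1, 2, 2)
lhs = np.transpose(x + bias, perm)                       # Transpose below the Add
rhs = np.transpose(x, perm) + np.transpose(bias4, perm)  # Transpose pushed above the Add
assert np.allclose(lhs, rhs)
```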
Binary file not shown.
@@ -0,0 +1,60 @@
import numpy as np
import onnx
from onnx import TensorProto, helper


# Create a model with shared initializers that can be updated in-place by the transpose optimizer,
# including ones behind a DQ node. The transpose optimizer updates the first usage and inserts
# Transpose/Unsqueeze ops on the others (see UnsqueezeInput and TransposeInput).
# When we push the Transpose past other usages we should be able to cancel out those Transpose/Unsqueeze ops.
# We need 3 DQ nodes to ensure the Transpose or Unsqueeze added by the transpose optimizer is not
# removed prematurely.
def create_model(broadcast_weights: bool):
if broadcast_weights:
bias_shape = [2, 2]
bias_values = np.random.randn(2, 2)
else:
bias_shape = [1, 3, 2, 2]
bias_values = np.random.randn(1, 3, 2, 2)

graph = helper.make_graph(
name="graph",
inputs=[
helper.make_tensor_value_info("input0", TensorProto.FLOAT, [1, 2, 2, 3]),
],
initializer=[
helper.make_tensor("bias_quant", TensorProto.UINT8, bias_shape, bias_values.astype(np.uint8)),
helper.make_tensor("bias_fp32", TensorProto.FLOAT, bias_shape, bias_values.astype(np.float32)),
helper.make_tensor("dq_scale0", TensorProto.FLOAT, [], [1.5]),
helper.make_tensor("dq_zp0", TensorProto.UINT8, [], [5]),
helper.make_tensor("dq_scale1", TensorProto.FLOAT, [], [0.5]),
],
nodes=[
# Transpose input from channels last to channels first
helper.make_node("Transpose", ["input0"], ["input_T"], perm=[0, 3, 1, 2]),
helper.make_node("DequantizeLinear", ["bias_quant", "dq_scale0", "dq_zp0"], ["DQ0"], "DQ0"),
helper.make_node("Add", ["input_T", "DQ0"], ["A0"], "A0"),
helper.make_node("DequantizeLinear", ["bias_quant", "dq_scale1"], ["DQ1"], "DQ1"),
helper.make_node("Add", ["A0", "DQ1"], ["A1"], "A1"),
helper.make_node("DequantizeLinear", ["bias_quant", "dq_scale0"], ["DQ2"], "DQ2"),
helper.make_node("Add", ["A1", "DQ2"], ["A2"], "A2"),
helper.make_node("Add", ["A2", "bias_fp32"], ["A3"], "A3"),
helper.make_node("Add", ["A3", "bias_fp32"], ["A4"], "A4"),
# NCHW to NHWC
helper.make_node("Transpose", ["A4"], ["output0"], perm=[0, 2, 3, 1]),
],
outputs=[
helper.make_tensor_value_info("output0", TensorProto.FLOAT, [1, 2, 2, 3]),
],
)

model = helper.make_model(graph)
onnx.checker.check_model(model, full_check=True)
return model


if __name__ == "__main__":
model = create_model(broadcast_weights=False)
onnx.save(model, "transpose_optimizer_shared_initializers.onnx")
model = create_model(broadcast_weights=True)
onnx.save(model, "transpose_optimizer_shared_initializers_broadcast.onnx")
Binary file not shown.
