Skip to content

Commit

Permalink
Fix CPU constant folding not reverting the node to its previous EP (m…
Browse files Browse the repository at this point in the history
…icrosoft#17399)

A recent change was made in
microsoft@5a83a67
to make `ep_type` a reference instead of having it be a copy, presumably
to avoid assigning strings (so `auto& ep_type =
node->GetExecutionProviderType()` instead of `auto ep_type =
node->GetExecutionProviderType()`). The problem with this change is that
calling `node->SetExecutionProviderType(kCpuExecutionProvider)` overwrites
the value that the reference refers to (a reference cannot be reseated),
so the node's original EP is lost and can no longer be restored.

This change fixes the bug and also improves on the previous approach:
the string is copied only when the node is on a non-CPU EP — the only case
where the original EP needs to be saved and restored.
  • Loading branch information
PatriceVignola authored Sep 12, 2023
1 parent bf6d696 commit 8ad9ab1
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 7 deletions.
19 changes: 12 additions & 7 deletions onnxruntime/core/optimizer/constant_folding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,19 +185,24 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level,
fetch_mlvalue_idxs.push_back(info.GetMLValueIndex(node_out->Name()));
}

auto& ep_type = node->GetExecutionProviderType();
const bool node_on_cpu_ep = ep_type == kCpuExecutionProvider;
const bool node_on_cpu_ep = node->GetExecutionProviderType() == kCpuExecutionProvider;

std::unique_ptr<const OpKernel> kernel;

// override the EP assigned to the node so that it will use the CPU kernel for Compute.
if (!node_on_cpu_ep) {
// We need to copy the string here instead of taking a reference to it since node->SetExecutionProviderType
// will change the value of the reference
auto ep_type = node->GetExecutionProviderType();

// override the EP assigned to the node so that it will use the CPU kernel for Compute.
node->SetExecutionProviderType(kCpuExecutionProvider);
}

auto kernel = info.CreateKernel(node);
kernel = info.CreateKernel(node);

// undo the EP change to the value that was assigned at graph partitioning time
if (!node_on_cpu_ep) {
// undo the EP change to the value that was assigned at graph partitioning time
node->SetExecutionProviderType(ep_type);
} else {
kernel = info.CreateKernel(node);
}

// We currently constant fold using the CPU EP only.
Expand Down
30 changes: 30 additions & 0 deletions onnxruntime/test/optimizer/graph_transform_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,36 @@ TEST_F(GraphTransformationTests, ConstantFoldingNodesOnDifferentEP) {
}
}

// Regression test for microsoft#17399: when constant folding temporarily
// reassigns a node to the CPU EP and kernel creation fails (no CPU float16
// Mul kernel here), the node's original EP must be restored rather than
// left as CPU.
TEST_F(GraphTransformationTests, ConstantFoldingUnsupportedFloat16) {
// Test model: a single float16 Mul node fed by two Constant initboth inputs.
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "constant_float16_mul.onnx";
std::shared_ptr<Model> model;
ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_));
Graph& graph = model->MainGraph();
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
// Sanity check: exactly one Mul node exists before the transformer runs.
ASSERT_TRUE(op_to_count["Mul"] == 1);
std::unique_ptr<CPUExecutionProvider> e =
std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/), TransformerLevel::Level1));

// assign all nodes to CUDA. the constant folding should try folding the node on the CPU and fail, thus leaving the
// EP as CUDA and not constant folding the node.
for (auto& node : graph.Nodes()) {
node.SetExecutionProviderType(kCudaExecutionProvider);
}

ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));

// The Mul must NOT have been folded away, since no CPU kernel was available.
op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Mul"] == 1);

// all nodes should still be on CUDA
for (auto& node : graph.Nodes()) {
EXPECT_STREQ(node.GetExecutionProviderType().c_str(), kCudaExecutionProvider);
}
}

TEST_F(GraphTransformationTests, ConstantFoldingSubgraph) {
TensorProto value_tensor;
value_tensor.add_dims(1);
Expand Down
17 changes: 17 additions & 0 deletions onnxruntime/test/testdata/transform/constant_float16_mul.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
onnxruntime_test:�
2c1c1_node"Constant*
value*
*�xBc1v�
3c2c2_node"Constant*
value*
*��Bc2v�

c1
c2
mul_outputmul"Mul float16_mulb

mul_output



B
Expand Down

0 comments on commit 8ad9ab1

Please sign in to comment.