[QNN] MatMulAddFusion and Reshape Related Fusion #22494

Open · wants to merge 10 commits into base: main
178 changes: 153 additions & 25 deletions onnxruntime/core/optimizer/matmul_add_fusion.cc
@@ -11,10 +11,63 @@
using namespace ::onnxruntime::common;
namespace onnxruntime {

namespace {

// The Attention subgraph has 4 MatMul-Add pairs that we want to skip here because AttentionFusion will handle them.
// In such a subgraph, 3 of the MatMul-Add pairs follow LayerNormalization (LN), and the remaining one produces an
// output that is added to LN's output. Use two sets to remember patterns already seen during the graph iteration
// so that other MatMul-Add pairs in the same pattern can be skipped directly.
struct AttentionPatternCache {
bool IsAttentionPattern(const Graph& graph, const Node& matmul_node, const Node& add_node) {
const Node* parent_node = graph.GetProducerNode(matmul_node.InputDefs()[0]->Name());
if (attn_ln_nodes.count(parent_node) > 0 || attn_add_nodes.count(&add_node) > 0) {
return true;
}

if (parent_node && parent_node->OpType() == "LayerNormalization") {
unsigned int add_count = 0;
unsigned int matmul_count = 0;
unsigned int shape_count = 0;
const Node* ln_add_node = nullptr;
for (auto it = parent_node->OutputNodesBegin(); it != parent_node->OutputNodesEnd(); ++it) {
std::string op_type = (*it).OpType();
if (op_type == "Add") {
ln_add_node = &(*it);
add_count++;
} else if (op_type == "MatMul") {
matmul_count++;
} else if (op_type == "Shape") {
shape_count++;
}
}

if (add_count == 1 && matmul_count == 3 && shape_count == parent_node->GetOutputEdgesCount() - 4) {
size_t index = ln_add_node->InputDefs()[0]->Name() == parent_node->OutputDefs()[0]->Name() ? 1 : 0;
const Node* attn_add_node = graph.GetProducerNode(ln_add_node->InputDefs()[index]->Name());
if (attn_add_node && attn_add_node->OpType() == "Add") {
attn_ln_nodes.insert(parent_node);
attn_add_nodes.insert(attn_add_node);
return true;
}
}
}

return false;
}

std::unordered_set<const Node*> attn_ln_nodes;
std::unordered_set<const Node*> attn_add_nodes;

[cpplint] onnxruntime/core/optimizer/matmul_add_fusion.cc:59: Add #include <unordered_set> for unordered_set<> [build/include_what_you_use] [4]
};

} // namespace

Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();

// Cache for skipping Attention subgraph pattern.
AttentionPatternCache attn_pattern_cache;

for (auto node_index : node_topology_list) {
auto* node_ptr = graph.GetNode(node_index);
if (!node_ptr)
@@ -65,58 +118,133 @@
// Gemm only supports matrices, so we need to check the shapes of the MatMul and Add inputs
auto matmul_a_shape = matmul_input_defs[0]->Shape();
auto matmul_b_shape = matmul_input_defs[1]->Shape();
if (nullptr == matmul_a_shape || nullptr == matmul_b_shape) {
if (nullptr == matmul_a_shape || nullptr == matmul_b_shape || matmul_b_shape->dim_size() != 2) {
continue;
}

if (2 != matmul_a_shape->dim_size() || 2 != matmul_b_shape->dim_size()) {
// Gemm only support Matrix
continue;
bool need_reshape = matmul_a_shape->dim_size() != 2;
const auto& dim_n = matmul_b_shape->dim(1);
InlinedVector<int64_t> shape_values;
int64_t m = 0, k = 0, n = 0;
if (need_reshape) {
// Only check and skip the Attention pattern here because the input to Attention is normally 4D.
if (attn_pattern_cache.IsAttentionPattern(graph, matmul_node, add_node)) {
continue;
}

// Logically we could use Shape+Concat to produce the shape input for Reshape. To keep it simple, we require
// both inputs to have concrete shapes for now; dynamic shape support can be added in the future.
auto a_shape = utils::GetTensorShapeFromTensorShapeProto(*matmul_a_shape);
if (a_shape.Size() == -1) {
continue;
}

const auto& dim_k = matmul_b_shape->dim(0);
if (!utils::HasDimValue(dim_k) || !utils::HasDimValue(dim_n)) {
continue;
}

shape_values = a_shape.AsShapeVector();
// If a_shape is 1D, m is 1, since SizeToDimension() is called with an empty dimension interval.
m = a_shape.SizeToDimension(a_shape.NumDimensions() - 1);
k = dim_k.dim_value();
n = dim_n.dim_value();
}

const auto& matmul_output = *matmul_node.OutputDefs()[0];

auto matmul_output_name = matmul_output.Name();
auto gemm_input_defs = matmul_input_defs;
if (matmul_output_name == add_input_defs[0]->Name()) {
// matmul output as Add_A, should use Add_B as input C for gemm
gemm_input_defs.push_back(add_input_defs[1]);
} else {
// matmul output as Add_B, should use Add_A as input C for gemm
gemm_input_defs.push_back(add_input_defs[0]);
}
int bias_idx = matmul_output_name == add_input_defs[0]->Name() ? 1 : 0;
gemm_input_defs.push_back(add_input_defs[bias_idx]);

// valid bias_shapes are (N) or (1, N) or (M, 1) or (M, N) as
// GEMM only supports unidirectional broadcast on the bias input C
if (!gemm_input_defs.back()->Shape()) {
continue;
}
const auto& bias_shape = *gemm_input_defs.back()->Shape();
const auto& M = matmul_output.Shape()->dim()[0];
const auto& N = matmul_output.Shape()->dim()[1];
auto dim_has_value_1 = [](const TensorShapeProto_Dimension& dim) {
return dim.has_dim_value() && dim.dim_value() == 1;
};

bool valid = ((bias_shape.dim_size() == 1 && bias_shape.dim()[0] == N) ||
(bias_shape.dim_size() == 2 && dim_has_value_1(bias_shape.dim()[0]) && bias_shape.dim()[1] == N) ||
(bias_shape.dim_size() == 2 && bias_shape.dim()[0] == M &&
(dim_has_value_1(bias_shape.dim()[1]) || bias_shape.dim()[1] == N)));
bool valid = ((bias_shape.dim_size() == 1 && bias_shape.dim(0) == dim_n) ||
(!need_reshape && bias_shape.dim_size() == 2 && dim_has_value_1(bias_shape.dim(0)) &&
bias_shape.dim(1) == dim_n) ||
(!need_reshape && bias_shape.dim_size() == 2 && bias_shape.dim(0) == matmul_a_shape->dim(0) &&
(dim_has_value_1(bias_shape.dim(1)) || bias_shape.dim(1) == dim_n)));
if (!valid) {
continue;
}

Node& gemm_node = graph.AddNode(graph.GenerateNodeName(matmul_node.Name() + "/MatMulAddFusion/"),
"Gemm",
"fused Matmul and Add " + add_node.OpType(),
gemm_input_defs,
{});
auto gemm_output_defs = add_node.MutableOutputDefs();
Node* input_node = nullptr;
Node* output_node = nullptr;
if (need_reshape) {
auto add_reshape = [&](const InlinedVector<int64_t>& shape, Graph& graph, bool is_input) -> Node* {
const std::string name = is_input ? "gemm_input" : "gemm_output";

[cpplint] onnxruntime/core/optimizer/matmul_add_fusion.cc:185: Add #include <string> for string [build/include_what_you_use] [4]
ONNX_NAMESPACE::TensorProto shape_initializer_proto;
shape_initializer_proto.set_name(graph.GenerateNodeName(name + "_shape"));
shape_initializer_proto.add_dims(static_cast<int64_t>(shape.size()));
shape_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
shape_initializer_proto.set_raw_data(shape.data(), shape.size() * sizeof(int64_t));
NodeArg* shape_arg = &graph_utils::AddInitializer(graph, shape_initializer_proto);
ONNX_NAMESPACE::TypeProto new_arg_type;
const ONNX_NAMESPACE::TensorProto_DataType element_type = static_cast<ONNX_NAMESPACE::TensorProto_DataType>(
gemm_input_defs[0]->TypeAsProto()->tensor_type().elem_type());
new_arg_type.mutable_tensor_type()->set_elem_type(element_type);
new_arg_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(m);
new_arg_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(is_input ? k : n);
NodeArg* new_arg = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(name + "_reshape_arg"), &new_arg_type);
Node& reshape_node = graph.AddNode(graph.GenerateNodeName(name + "_reshape"), "Reshape", "Reshape for " + name,
{is_input ? gemm_input_defs[0] : new_arg, shape_arg},
{is_input ? new_arg : gemm_output_defs[0]});
reshape_node.SetExecutionProviderType(matmul_node.GetExecutionProviderType());
return &reshape_node;
};

input_node = add_reshape({m, k}, graph, true);
gemm_input_defs[0] = input_node->MutableOutputDefs()[0];
shape_values.back() = n;
output_node = add_reshape(shape_values, graph, false);
gemm_output_defs[0] = output_node->MutableInputDefs()[0];
}

// Assign a provider to this new node. It should be the same as the provider of the old node.
Node& gemm_node = graph.AddNode(graph.GenerateNodeName(matmul_node.Name() + "/MatMulAddFusion"), "Gemm",
"fused Matmul and Add", gemm_input_defs, gemm_output_defs);
gemm_node.SetExecutionProviderType(matmul_node.GetExecutionProviderType());

// move output definitions and edges from act_node to gemm_node. delete gemm_node and act_node.
graph_utils::FinalizeNodeFusion(graph, {matmul_node, add_node}, gemm_node);
if (need_reshape) {
graph.AddEdge(input_node->Index(), gemm_node.Index(), 0, 0);
graph.AddEdge(gemm_node.Index(), output_node->Index(), 0, 0);
} else {
input_node = &gemm_node;
output_node = &gemm_node;
}

auto matmul_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(matmul_node);
for (auto cur = matmul_input_edges.cbegin(), end = matmul_input_edges.cend(); cur != end; ++cur) {
if (cur->dst_arg_index == 0) {
graph.AddEdge(cur->src_node, input_node->Index(), cur->src_arg_index, 0);
} else if (cur->dst_arg_index == 1) {
graph.AddEdge(cur->src_node, gemm_node.Index(), cur->src_arg_index, 1);
}
}

graph_utils::GraphEdge::RemoveGraphEdges(graph, matmul_input_edges);
auto add_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(add_node);
for (auto cur = add_input_edges.cbegin(), end = add_input_edges.cend(); cur != end; ++cur) {
if (cur->dst_arg_index == bias_idx) {
graph.AddEdge(cur->src_node, gemm_node.Index(), cur->src_arg_index, 2);
break;
}
}

graph_utils::GraphEdge::RemoveGraphEdges(graph, add_input_edges);
graph_utils::RemoveNodeOutputEdges(graph, matmul_node);
graph_utils::ReplaceDownstreamNodeInput(graph, add_node, 0, *output_node, 0);
graph.RemoveNode(matmul_node.Index());
graph.RemoveNode(add_node.Index());

modified = true;
}
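For readers following the new need_reshape path above, here is a small standalone C++ sketch (illustration only, not part of this PR) of the shape arithmetic it performs when MatMul's input A is not 2-D: A is flattened to [m, k], Gemm computes [m, k] x [k, n] plus the bias, and the result is reshaped back to A's leading dimensions with the last dimension replaced by n. The ComputeReshapedGemmShapes helper and the concrete shapes are hypothetical.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper (not from the PR): mirrors the m/k/n values computed via
// SizeToDimension() and dim_value() in matmul_add_fusion.cc when input A is not 2-D.
struct ReshapedGemmShapes {
  std::vector<int64_t> gemm_a_shape;    // {m, k}: A flattened for Gemm
  std::vector<int64_t> gemm_out_shape;  // {m, n}: Gemm output before the final Reshape
  std::vector<int64_t> final_shape;     // A's shape with the last dim replaced by n
};

ReshapedGemmShapes ComputeReshapedGemmShapes(const std::vector<int64_t>& a_shape,
                                             int64_t k, int64_t n) {
  int64_t m = 1;
  for (size_t i = 0; i + 1 < a_shape.size(); ++i) {
    m *= a_shape[i];  // product of all leading (non-contraction) dims; 1 if A is 1-D
  }
  std::vector<int64_t> final_shape = a_shape;
  final_shape.back() = n;
  return {{m, k}, {m, n}, final_shape};
}

int main() {
  // Example: MatMul([2, 8, 16], [16, 32]) + Add([32]) becomes
  // Reshape(A, [16, 16]) -> Gemm -> Reshape(out, [2, 8, 32]).
  ReshapedGemmShapes shapes = ComputeReshapedGemmShapes({2, 8, 16}, /*k=*/16, /*n=*/32);
  std::cout << shapes.gemm_a_shape[0] << "x" << shapes.gemm_a_shape[1] << " -> "
            << shapes.final_shape[0] << "x" << shapes.final_shape[1] << "x"
            << shapes.final_shape[2] << "\n";  // prints "16x16 -> 2x8x32"
  return 0;
}
```

The two Reshape nodes inserted by the fusion simply carry these {m, k} and final-shape values as int64 initializers, which is why the pass requires concrete shapes.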
51 changes: 51 additions & 0 deletions onnxruntime/core/optimizer/reshape_fusion.cc
@@ -48,6 +48,8 @@
fused_count++;
LOGS(logger, INFO) << "Fused reshape node: " << reshape.OutputDefs()[0]->Name();
modified = true;
} else if (ReshapeFusion::FuseContiguousReshapes(reshape, graph)) {
modified = true;
}
}

@@ -452,4 +454,53 @@
return true;
}

bool ReshapeFusion::FuseContiguousReshapes(Node& reshape, Graph& graph) {
InlinedVector<std::reference_wrapper<Node>> contiguous_reshapes{reshape};
InlinedVector<int64_t> shape_value;
while (true) {
Node& curr_node = contiguous_reshapes.back();
if (graph.NodeProducesGraphOutput(curr_node) || curr_node.GetOutputEdgesCount() != 1) {
break;
}

Node* next_node = graph.GetNode(curr_node.OutputNodesBegin()->Index());
if (next_node->OpType() != "Reshape" && next_node->OpType() != "Squeeze" && next_node->OpType() != "Unsqueeze") {
break;
}

auto shape = next_node->OutputDefs()[0]->Shape();
if (!shape) {
break;
}

auto tensor_shape = utils::GetTensorShapeFromTensorShapeProto(*shape);
if (tensor_shape.Size() == -1) {
break;
}

shape_value = tensor_shape.AsShapeVector();
contiguous_reshapes.emplace_back(*next_node);
}

if (contiguous_reshapes.size() < 2) {
return false;
}

const std::string& name = contiguous_reshapes[0].get().Name();

[cpplint] onnxruntime/core/optimizer/reshape_fusion.cc:489: Add #include <string> for string [build/include_what_you_use] [4]
ONNX_NAMESPACE::TensorProto shape_initializer_proto;
shape_initializer_proto.set_name(graph.GenerateNodeName(name + "_new_shape"));
shape_initializer_proto.add_dims(static_cast<int64_t>(shape_value.size()));
shape_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
shape_initializer_proto.set_raw_data(shape_value.data(), shape_value.size() * sizeof(int64_t));
NodeArg* shape_arg = &graph_utils::AddInitializer(graph, shape_initializer_proto);
Node& reshape_node = graph.AddNode(graph.GenerateNodeName(name + "_new_reshape"), "Reshape", "Reshape for " + name,
{contiguous_reshapes[0].get().MutableInputDefs()[0], shape_arg},
{contiguous_reshapes.back().get().MutableOutputDefs()[0]});
reshape_node.SetExecutionProviderType(contiguous_reshapes[0].get().GetExecutionProviderType());

graph_utils::FinalizeNodeFusion(graph, contiguous_reshapes, reshape_node);

return true;
}

} // namespace onnxruntime
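To illustrate why FuseContiguousReshapes only fires when every node in the chain has a concrete output shape, here is a small standalone C++ sketch (illustration only, not part of this PR): a run of Reshape/Squeeze/Unsqueeze nodes whose output shapes are fully static collapses into a single Reshape to the final shape, because only the last shape matters. The CanFoldShapeChain helper and the example shapes are hypothetical; the element-count comparison is just a sanity check here, since shape-only ops preserve it by construction.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper (not from the PR): returns true if every output shape in a
// Reshape/Squeeze/Unsqueeze chain is fully static and preserves the element count,
// in which case the whole chain is equivalent to one Reshape to the final shape.
// This mirrors the GetTensorShapeFromTensorShapeProto(...).Size() == -1 early-out above.
bool CanFoldShapeChain(const std::vector<std::vector<int64_t>>& chain_output_shapes) {
  auto num_elements = [](const std::vector<int64_t>& shape) -> int64_t {
    int64_t size = 1;
    for (int64_t d : shape) {
      if (d < 0) return -1;  // symbolic/unknown dimension
      size *= d;
    }
    return size;
  };
  const int64_t expected = num_elements(chain_output_shapes.front());
  if (expected < 0) return false;
  for (const auto& shape : chain_output_shapes) {
    if (num_elements(shape) != expected) return false;  // not static, or not shape-only
  }
  return true;
}

int main() {
  // Reshape -> Squeeze -> Unsqueeze producing [1, 3, 4] -> [3, 4] -> [3, 1, 4]:
  // all shapes are static, so the chain folds into a single Reshape to [3, 1, 4].
  std::vector<std::vector<int64_t>> chain = {{1, 3, 4}, {3, 4}, {3, 1, 4}};
  std::cout << std::boolalpha << CanFoldShapeChain(chain) << "\n";  // prints "true"
  return 0;
}
```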
5 changes: 5 additions & 0 deletions onnxruntime/core/optimizer/reshape_fusion.h
@@ -27,6 +27,11 @@ class ReshapeFusion : public GraphTransformer {
static bool Is_One_Element_Input(const Node& cur_node, int index);
static bool Is_One_Element_Output_Subgraph(Graph& graph, const NodeArg& root_input, const Node& concat,
int index, gsl::span<const int64_t> shape_value, const logging::Logger& logger);

// Remove contiguous Reshape/Squeeze/Unsqueeze nodes if the shape info is concrete.
// For some EPs these reshape ops are not no-ops; for example, on QNN EP memory is allocated for each output,
// so this fusion helps reduce memory usage on such devices.
static bool FuseContiguousReshapes(Node& reshape, Graph& graph);
};

} // namespace onnxruntime
@@ -7,8 +7,6 @@
#include <core/providers/common.h>

#include "core/providers/shared/utils/utils.h"
#include "core/framework/tensorprotoutils.h"
#include "core/providers/cpu/tensor/transpose.h"
#include "core/common/safeint.h"

namespace onnxruntime {
@@ -271,37 +269,6 @@ Status BaseOpBuilder::SetOutputQParamEqualToInputIfNearlyEqual(QnnModelWrapper&
return Status::OK();
}

Status BaseOpBuilder::TransposeInitializer(const QnnModelWrapper& qnn_model_wrapper,
const onnx::TensorProto& initializer,
const std::vector<size_t>& perm,
std::vector<uint8_t>& transposed_data) const {
const DataTypeImpl* tensor_dtype = DataTypeImpl::TensorTypeFromONNXEnum(initializer.data_type())->GetElementType();
const auto tensor_shape_dims = onnxruntime::utils::GetTensorShapeFromTensorProto(initializer);
TensorShape tensor_shape{tensor_shape_dims};
AllocatorPtr cpu_allocator = std::make_shared<CPUAllocator>();
Tensor in_tensor = Tensor(tensor_dtype, tensor_shape, cpu_allocator);

auto rank = perm.size();
std::vector<int64_t> new_tensor_shape_dims;
std::vector<size_t> permutations;
new_tensor_shape_dims.reserve(rank);
permutations.reserve(rank);
for (int64_t p : perm) {
permutations.push_back(p);
new_tensor_shape_dims.push_back(tensor_shape_dims[p]);
}

TensorShape new_tensor_shape(new_tensor_shape_dims);
Tensor out_tensor = Tensor(tensor_dtype, new_tensor_shape, cpu_allocator);
ORT_RETURN_IF_ERROR(onnxruntime::utils::TensorProtoToTensor(
Env::Default(), qnn_model_wrapper.GetGraphViewer().ModelPath(), initializer, in_tensor));
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(permutations, in_tensor, out_tensor));
onnx::TensorProto new_tensor_proto = onnxruntime::utils::TensorToTensorProto(out_tensor, "test");
ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(new_tensor_proto, transposed_data));

return Status::OK();
}

Status BaseOpBuilder::ProcessAxisAttribute(const QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
Qnn_Scalar_t& axis_qnn_scalar,