From f61cca1b8f5438deaad26dd47e36ee827d3cd139 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Fri, 5 Apr 2024 14:58:39 +1000 Subject: [PATCH] NNAPI: Improve MatMul diagnostic output (#19721) ### Description Re-order so that we don't get two messages for the one node. Currently the batched matmul 'not supported' message will appear for 2D input which is valid, which can be confusing to understand. Change the order so we only check if batched matmul can be used when the input ranks are > 3, as that is one of the requirements. https://github.com/microsoft/onnxruntime/blob/c311d1faf50167e38613927e44c8a430ffcc8e89/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc#L257-L264 --- .../builders/impl/gemm_op_builder.cc | 38 ++++++++----------- 1 file changed, 16 insertions(+), 22 deletions(-) diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/gemm_op_builder.cc index 8488f7cc74a6e..66eefcd6e4840 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/gemm_op_builder.cc @@ -297,12 +297,6 @@ bool GemmOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const Nod const OpSupportCheckParams& params) const { // check batch matmul first, then fall back to checking single gemm/matmul { - const bool is_supported_batch_matmul = - op_builder_helpers::IsSupportedBatchMatMul(node_unit, params.android_feature_level); - LOGS_DEFAULT(VERBOSE) << "Supported batch matmul: [" << is_supported_batch_matmul << "]"; - if (is_supported_batch_matmul) { - return true; - } } const auto& op_type = node_unit.OpType(); @@ -312,25 +306,25 @@ bool GemmOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const Nod const bool is_quant_gemm = quant_type == QuantizedOpType::QDQGemm; Shape a_shape; - { - if (!GetShape(inputs[0].node_arg, a_shape)) - return false; - - if (a_shape.size() != 2) { - LOGS_DEFAULT(VERBOSE) << "A must be 2D"; - return false; - } + Shape b_shape; + if (!GetShape(inputs[0].node_arg, a_shape) || !GetShape(inputs[1].node_arg, b_shape)) { + return false; } - Shape b_shape; - { - if (!GetShape(inputs[1].node_arg, b_shape)) - return false; + auto a_rank = a_shape.size(); + auto b_rank = b_shape.size(); - if (b_shape.size() != 2) { - LOGS_DEFAULT(VERBOSE) << "B must be 2D"; - return false; - } + if (a_rank == 2 && b_rank == 2) { + // can potentially use FullyConnected + } else if (a_rank > 2 && b_rank > 2) { + // can maybe use our manual batched MatMul implementation + const bool is_supported_batch_matmul = op_builder_helpers::IsSupportedBatchMatMul(node_unit, + params.android_feature_level); + LOGS_DEFAULT(VERBOSE) << "Supported batch matmul: [" << is_supported_batch_matmul << "]"; + return is_supported_batch_matmul; + } else { + LOGS_DEFAULT(VERBOSE) << "A and B must be 2D"; + return false; } if (op_type == "Gemm") {