Skip to content

Commit

Permalink
NNAPI: Improve MatMul diagnostic output (#19721)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->
Re-order the checks so that we don't get two diagnostic messages for a single node.

Currently the batched MatMul 'not supported' message also appears for 2D
input, which is valid, and this makes the diagnostic output confusing.

Change the order so we only check whether batched MatMul can be used when
both input ranks are > 2 (i.e. at least 3), as that is one of its requirements.

https://github.com/microsoft/onnxruntime/blob/c311d1faf50167e38613927e44c8a430ffcc8e89/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc#L257-L264
  • Loading branch information
skottmckay authored Apr 5, 2024
1 parent 254bdbb commit f61cca1
Showing 1 changed file with 16 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -297,12 +297,6 @@ bool GemmOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const Nod
const OpSupportCheckParams& params) const {
// check batch matmul first, then fall back to checking single gemm/matmul
{
const bool is_supported_batch_matmul =
op_builder_helpers::IsSupportedBatchMatMul(node_unit, params.android_feature_level);
LOGS_DEFAULT(VERBOSE) << "Supported batch matmul: [" << is_supported_batch_matmul << "]";
if (is_supported_batch_matmul) {
return true;
}
}

const auto& op_type = node_unit.OpType();
Expand All @@ -312,25 +306,25 @@ bool GemmOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const Nod
const bool is_quant_gemm = quant_type == QuantizedOpType::QDQGemm;

Shape a_shape;
{
if (!GetShape(inputs[0].node_arg, a_shape))
return false;

if (a_shape.size() != 2) {
LOGS_DEFAULT(VERBOSE) << "A must be 2D";
return false;
}
Shape b_shape;
if (!GetShape(inputs[0].node_arg, a_shape) || !GetShape(inputs[1].node_arg, b_shape)) {
return false;
}

Shape b_shape;
{
if (!GetShape(inputs[1].node_arg, b_shape))
return false;
auto a_rank = a_shape.size();
auto b_rank = b_shape.size();

if (b_shape.size() != 2) {
LOGS_DEFAULT(VERBOSE) << "B must be 2D";
return false;
}
if (a_rank == 2 && b_rank == 2) {
// can potentially use FullyConnected
} else if (a_rank > 2 && b_rank > 2) {
// can maybe use our manual batched MatMul implementation
const bool is_supported_batch_matmul = op_builder_helpers::IsSupportedBatchMatMul(node_unit,
params.android_feature_level);
LOGS_DEFAULT(VERBOSE) << "Supported batch matmul: [" << is_supported_batch_matmul << "]";
return is_supported_batch_matmul;
} else {
LOGS_DEFAULT(VERBOSE) << "A and B must be 2D";
return false;
}

if (op_type == "Gemm") {
Expand Down

0 comments on commit f61cca1

Please sign in to comment.