Skip to content

Commit

Permalink
create memory descriptors based on the tensor dimensions
Browse files Browse the repository at this point in the history
The Arm Compute Library (ACL) backend requires explicit memory format tag
initialization to decide whether the tensor can be computed with the
ACL kernels. Hence, the src, weights and dst memory descriptor formats are
set based on the tensor dimensions instead of using the format::any
tag.
  • Loading branch information
snadampal committed Oct 26, 2023
1 parent d88d52e commit 31356ce
Showing 1 changed file with 22 additions and 3 deletions.
25 changes: 22 additions & 3 deletions onnxruntime/core/providers/dnnl/subgraph/dnnl_matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,25 @@
namespace onnxruntime {
namespace ort_dnnl {

// Maps the number of dimensions in `tensor_dims` to the format tag that has
// one letter per dimension (a, ab, abc, ... abcdef).
//
// Returns format_tag::undef when the dimension count is 0 or greater than 6,
// since no entry is provided for those ranks.
inline static dnnl::memory::format_tag get_default_format(const dnnl::memory::dims& tensor_dims) {
  using tag = dnnl::memory::format_tag;

  // Lookup table indexed directly by rank; slot 0 is a placeholder so that
  // an empty dims vector falls through to undef like every unsupported rank.
  static constexpr tag kTagByRank[] = {
      tag::undef,
      tag::a,
      tag::ab,
      tag::abc,
      tag::abcd,
      tag::abcde,
      tag::abcdef,
  };
  constexpr auto kTableSize = sizeof(kTagByRank) / sizeof(kTagByRank[0]);

  const auto rank = tensor_dims.size();
  if (rank >= kTableSize) {
    return tag::undef;
  }
  return kTagByRank[rank];
}

// Default constructor: performs no work beyond default-initializing members.
DnnlMatMul::DnnlMatMul() = default;

// This handles ONNX defined "MatMul" as well as two other variations of MatMul
Expand Down Expand Up @@ -139,14 +158,14 @@ void DnnlMatMul::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
if (transA || transBatchA) {
src_md = transposedA_md;
} else {
src_md = dnnl::memory::desc(src_dims, node.Input(IN_A).Type(), dnnl::memory::format_tag::any);
src_md = dnnl::memory::desc(src_dims, node.Input(IN_A).Type(), get_default_format(src_dims));
}

dnnl::memory::desc weights_md;
if (transB || transBatchB) {
weights_md = transposedB_md;
} else {
weights_md = dnnl::memory::desc(weights_dims, node.Input(IN_B).Type(), dnnl::memory::format_tag::any);
weights_md = dnnl::memory::desc(weights_dims, node.Input(IN_B).Type(), get_default_format(weights_dims));
}

auto output_shape = src_dims;
Expand Down Expand Up @@ -241,7 +260,7 @@ void DnnlMatMul::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
attr.set_scales_mask(DNNL_ARG_SRC, 0);
}

auto dst_md = dnnl::memory::desc(output_shape, node.Output(OUT_Y).Type(), dnnl::memory::format_tag::any);
auto dst_md = dnnl::memory::desc(output_shape, node.Output(OUT_Y).Type(), get_default_format(output_shape));

auto matmul_pd = dnnl::matmul::primitive_desc(eng, src_md, weights_md, dst_md, attr);

Expand Down

0 comments on commit 31356ce

Please sign in to comment.