diff --git a/src/cpu/aarch64/matmul/acl_matmul_utils.cpp b/src/cpu/aarch64/matmul/acl_matmul_utils.cpp index a921422ac0b..9b23f9a0c29 100644 --- a/src/cpu/aarch64/matmul/acl_matmul_utils.cpp +++ b/src/cpu/aarch64/matmul/acl_matmul_utils.cpp @@ -47,10 +47,26 @@ status_t init_conf_matmul(acl_matmul_conf_t &, memory_desc_t &src_md, // for e.g when ab in abcd is 1x1 bool batch_ok = IMPLICATION(src_batch > 1, wei_batch == 1) && IMPLICATION(wei_batch > 1, src_batch == 1); + ACL_CHECK_SUPPORT(src_d.ndims() == 4 && src_batch != wei_batch && !batch_ok, "matmul broadcast supported only for 3D shapes and 4D shapes when " "ab is 1x1"); + if (src_d.ndims() == 4 && src_batch == wei_batch + && src_d.dims()[0] != wei_d.dims()[0]) { // 4D broadcast occurred + if (src_d.dims()[0] == 1 && wei_d.dims()[0] != 1) { // Broadcast src + ACL_CHECK_SUPPORT( + IMPLICATION(src_d.dims()[1] != 1, wei_d.dims()[1] == 1), + "acl only broadcasts one of src or wei at once"); + } + + if (wei_d.dims()[0] == 1 && src_d.dims()[0] != 1) { // Broadcast wei + ACL_CHECK_SUPPORT( + IMPLICATION(src_d.dims()[1] == 1, wei_d.dims()[1] != 1), + "acl only broadcasts one of src or wei at once"); + } + } + // ACL does not support bias bool with_bias = md.bias_desc.format_kind != format_kind::undef; ACL_CHECK_SUPPORT(with_bias, "ACL does not support bias for matmul"); diff --git a/tests/benchdnn/inputs/matmul/shapes_4d b/tests/benchdnn/inputs/matmul/shapes_4d index 7a8aa33de14..924b6607d6c 100644 --- a/tests/benchdnn/inputs/matmul/shapes_4d +++ b/tests/benchdnn/inputs/matmul/shapes_4d @@ -18,5 +18,6 @@ 74x16x54x64:74x16x64x54 1x1x35x64:113x16x64x35 1x16x38x64:105x1x64x38 +1x3x35x64:3x1x64x35 74x16x54x64:1x1x64x54n"B_full_bcast" 74x6x1x253:1x1x253x1n"dot_prod_w_B_full_bcast"