From 8977ef9620ada75fe696bf7e18201d4f9a24a99e Mon Sep 17 00:00:00 2001 From: Jan Iwaszkiewicz Date: Mon, 12 Aug 2024 13:37:24 +0200 Subject: [PATCH] Handle rank 0 and 1 --- compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp index d12d129fef3..bceb86e2f5e 100644 --- a/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp +++ b/compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp @@ -136,7 +136,7 @@ bool fuse_mul_with_fc(luci::CircleFullyConnected *fc) // Check that all dimensions are ones, checks broadcast capabilites. // Last dimesion of multiplication must be compatible with FC. // N-D case (N>1): - if (multiplication->rank() >= 1) + if (multiplication->rank() > 1) { // Check channel-wise broadcasting: for (uint32_t i = 0; i < rank - 1; i++) @@ -144,7 +144,12 @@ bool fuse_mul_with_fc(luci::CircleFullyConnected *fc) // Check the last dimesion of Mul is the same with the first dimension of FullyConnected RETURN_FALSE_UNLESS(multiplication->dim(rank - 1) == weights->dim(0)); } - // Scalar case: + // 1-D or scalar case: + else if (multiplication->rank() == 1) + { + RETURN_FALSE_UNLESS(multiplication->size() == 1 || + multiplication->size() == weights->dim(0)); + } else if (multiplication->rank() == 0) { RETURN_FALSE_UNLESS(multiplication->size() == 1);