Skip to content

Commit

Permalink
Merge branch 'jiwaszki/fuse_mul_fc_luci_pass' into jiwaszki/fuse_mul_fc
Browse files Browse the repository at this point in the history
  • Loading branch information
jiwaszki committed Aug 12, 2024
2 parents 1b6c71f + 8977ef9 commit dbed1b9
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions compiler/luci/pass/src/FuseMulWithFullyConnectedPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,15 +136,20 @@ bool fuse_mul_with_fc(luci::CircleFullyConnected *fc)
// Check that all dimensions are ones, checks broadcast capabilites.
// Last dimesion of multiplication must be compatible with FC.
// N-D case (N>1):
if (multiplication->rank() >= 1)
if (multiplication->rank() > 1)
{
// Check channel-wise broadcasting:
for (uint32_t i = 0; i < rank - 1; i++)
RETURN_FALSE_UNLESS(multiplication->dim(i).value() == 1);
// Check the last dimesion of Mul is the same with the first dimension of FullyConnected
RETURN_FALSE_UNLESS(multiplication->dim(rank - 1) == weights->dim(0));
}
// Scalar case:
// 1-D or scalar case:
else if (multiplication->rank() == 1)
{
RETURN_FALSE_UNLESS(multiplication->size<loco::DataType::FLOAT32>() == 1 ||
multiplication->size<loco::DataType::FLOAT32>() == weights->dim(0));
}
else if (multiplication->rank() == 0)
{
RETURN_FALSE_UNLESS(multiplication->size<loco::DataType::FLOAT32>() == 1);
Expand Down

0 comments on commit dbed1b9

Please sign in to comment.