diff --git a/xformer/Transforms/XCPatterns.td b/xformer/Transforms/XCPatterns.td index e03074c40..d4cdf8b48 100644 --- a/xformer/Transforms/XCPatterns.td +++ b/xformer/Transforms/XCPatterns.td @@ -46,8 +46,13 @@ def getExpLookupF32 def isSingleSegment : Constraint().getRank() == 2">>; -def isSingleBatch : Constraint().getDimSize(0) == 1">>; -def isMultiBatch : Constraint().getDimSize(0) != 1">>; +def isSingleBatch + : Constraint().getDimSize(0) == 1">>; +def isMultiBatch + : Constraint().getRank() == 2 && " + "$0.getType().cast().getDimSize(0) != 1) || " + "($0.getType().cast().getRank() == 3 && " + "$0.getType().cast().getDimSize(1) != 1)">>; def betaIsOne : Constraint>; @@ -61,7 +66,7 @@ def: Pat<(TFL_SoftmaxOp : $output TensorOf<[QI8]>:$input, $beta), (XC_BatchedSoftmaxOp $input, (Arith_ConstantOp (getExpLookupF32 - $output))), [(betaIsOne $beta), (isSingleSegment $input), (isMultiBatch $input)]>; + $output))), [(betaIsOne $beta), (isMultiBatch $input)]>; // Beta float activation lookup def getActivationType