Skip to content

Commit

Permalink
Handle batched softmax for three ranks
Browse files Browse the repository at this point in the history
  • Loading branch information
panickal-xmos committed Jul 14, 2024
1 parent ceeb41e commit 7fa5e87
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions xformer/Transforms/XCPatterns.td
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,13 @@ def getExpLookupF32
def isSingleSegment
: Constraint<CPred<"$0.getType().cast<ShapedType>().getRank() == 2">>;

def isSingleBatch : Constraint<CPred<"$0.getType().cast<ShapedType>().getDimSize(0) == 1">>;
def isMultiBatch : Constraint<CPred<"$0.getType().cast<ShapedType>().getDimSize(0) != 1">>;
def isSingleBatch
: Constraint<CPred<"$0.getType().cast<ShapedType>().getDimSize(0) == 1">>;
def isMultiBatch
: Constraint<CPred<"($0.getType().cast<ShapedType>().getRank() == 2 && "
"$0.getType().cast<ShapedType>().getDimSize(0) != 1) || "
"($0.getType().cast<ShapedType>().getRank() == 3 && "
"$0.getType().cast<ShapedType>().getDimSize(1) != 1)">>;

def betaIsOne : Constraint<CPred<"$0.getValue().convertToFloat() == 1.0">>;

Expand All @@ -61,7 +66,7 @@ def:
Pat<(TFL_SoftmaxOp
: $output TensorOf<[QI8]>:$input, $beta),
(XC_BatchedSoftmaxOp $input, (Arith_ConstantOp (getExpLookupF32
$output))), [(betaIsOne $beta), (isSingleSegment $input), (isMultiBatch $input)]>;
$output))), [(betaIsOne $beta), (isMultiBatch $input)]>;

// Beta float activation lookup
def getActivationType
Expand Down

0 comments on commit 7fa5e87

Please sign in to comment.