Skip to content

Commit

Permalink
Fix TTIR to TTNN conversion for all gather
Browse files Browse the repository at this point in the history
  • Loading branch information
gfengTT committed Nov 6, 2024
1 parent ae93524 commit 0c10623
Showing 1 changed file with 2 additions and 8 deletions.
10 changes: 2 additions & 8 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -835,15 +835,9 @@ class AllGatherOpConversionPattern
LogicalResult
matchAndRewrite(ttir::AllGatherOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
RankedTensorType type =
mlir::cast<RankedTensorType>(adaptor.getInput().getType());
Value device = getOrInsertDevice(rewriter, op);
tensor::EmptyOp emptyOp = rewriter.create<tensor::EmptyOp>(
op.getLoc(), this->getTypeConverter()->convertType(type), device);

rewriter.replaceOpWithNewOp<ttnn::AllGatherOp>(
op, this->getTypeConverter()->convertType(op.getType()), emptyOp,
adaptor.getDim());
op, this->getTypeConverter()->convertType(op.getType()),
adaptor.getInput(), adaptor.getDim());
return success();
}
};
Expand Down

0 comments on commit 0c10623

Please sign in to comment.