Skip to content

Commit

Permalink
[compiler] fix error cast in shape reification pass
Browse files Browse the repository at this point in the history
  • Loading branch information
XG-zheng committed Apr 9, 2024
1 parent da7551a commit 7363a9c
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 3 deletions.
3 changes: 2 additions & 1 deletion compiler/include/byteir/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,8 @@ def ShapeReification : Pass<"byteir-shape-reification", "func::FuncOp"> {
let constructor = "mlir::createByteIRShapeReificationPass()";
let dependentDialects = [
"mlir::shape::ShapeDialect",
"mlir::tensor::TensorDialect"
"mlir::tensor::TensorDialect",
"mlir::arith::ArithDialect",
];
}

Expand Down
4 changes: 4 additions & 0 deletions compiler/lib/Transforms/PassDetail.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ namespace memref {
class MemRefDialect;
} // namespace memref

namespace arith {
class ArithDialect;
} // namespace arith

namespace mhlo {
class MhloDialect;
} // namespace mhlo
Expand Down
5 changes: 3 additions & 2 deletions compiler/lib/Transforms/ShapeReification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "byteir/Dialect/mhlo/DynamicShapeOpRegister/Register.h"
#include "byteir/Dialect/mhlo/Util/ShapeInferUtil.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
Expand Down Expand Up @@ -59,8 +60,8 @@ struct ShapeReificationOnTensorDimPattern

// Insert cast, if needed.
if (dimOfShape.getType() != op.getType()) {
dimOfShape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
dimOfShape);
dimOfShape = rewriter.create<arith::IndexCastOp>(
op.getLoc(), op.getType(), dimOfShape);
}

rewriter.replaceOp(op, dimOfShape);
Expand Down

0 comments on commit 7363a9c

Please sign in to comment.