diff --git a/compiler/lib/Conversion/MemrefToByre/MemrefToByre.cpp b/compiler/lib/Conversion/MemrefToByre/MemrefToByre.cpp index f759d2d0b..0016ee5f1 100644 --- a/compiler/lib/Conversion/MemrefToByre/MemrefToByre.cpp +++ b/compiler/lib/Conversion/MemrefToByre/MemrefToByre.cpp @@ -50,6 +50,54 @@ class ConvertReshapeLikeOpToByrePattern : public OpConversionPattern { } }; +class ConvertCastOpToByrePattern : public OpConversionPattern { +public: + ConvertCastOpToByrePattern(MLIRContext *ctx) + : OpConversionPattern(ctx) {} + + LogicalResult + matchAndRewrite(memref::CastOp op, memref::CastOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (auto subview = op.getSource().getDefiningOp()) { + if (!subview.getSource().getType().getLayout().isIdentity()) + return failure(); + if (!op.getType().cast().getLayout().isIdentity()) + return failure(); + + rewriter.replaceOpWithNewOp(op, op.getType(), + subview.getSource(), 0); + return success(); + } + + return failure(); + } +}; + +class ConvertCollapseShapeOpToByrePattern + : public OpConversionPattern { +public: + ConvertCollapseShapeOpToByrePattern(MLIRContext *ctx) + : OpConversionPattern(ctx) {} + + LogicalResult + matchAndRewrite(memref::CollapseShapeOp op, + memref::CollapseShapeOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (auto subview = op.getSrc().getDefiningOp()) { + if (!subview.getSource().getType().getLayout().isIdentity()) + return failure(); + if (!op.getType().getLayout().isIdentity()) + return failure(); + + rewriter.replaceOpWithNewOp(op, op.getType(), + subview.getSource(), 0); + return success(); + } + + return failure(); + } +}; + class ConvertViewOpToByrePattern : public OpConversionPattern { public: ConvertViewOpToByrePattern(MLIRContext *ctx) @@ -196,7 +244,8 @@ void mlir::populateMemrefToByrePattern(RewritePatternSet &patterns) { ConvertGetGlobalOpToByrePattern, ConvertReshapeLikeOpToByrePattern, ConvertReshapeLikeOpToByrePattern, - ConvertSubViewOpToByrePattern>(patterns.getContext()); + ConvertSubViewOpToByrePattern, ConvertCastOpToByrePattern, + ConvertCollapseShapeOpToByrePattern>(patterns.getContext()); } std::unique_ptr> diff --git a/compiler/lib/Dialect/mhlo/Transforms/GenericFusion.cpp b/compiler/lib/Dialect/mhlo/Transforms/GenericFusion.cpp index 377911d16..653041f2f 100644 --- a/compiler/lib/Dialect/mhlo/Transforms/GenericFusion.cpp +++ b/compiler/lib/Dialect/mhlo/Transforms/GenericFusion.cpp @@ -73,7 +73,7 @@ bool isFusibleStart(Operation *op) { return true; } bool isFusibleTrigger(Operation *op) { if (op->hasTrait<::mlir::OpTrait::Elementwise>() || op->hasTrait() || - isa(op) || isCustomMhloRngOp(op)) { + isa(op) || isCustomMhloRngOp(op)) { return true; } @@ -101,8 +101,8 @@ bool isFusibleWith(Operation *target, Operation *start) { return target->hasTrait<::mlir::OpTrait::Elementwise>() || target->hasTrait() || isSplatMhloConstantLike(target) || - isa( - target) || + isa(target) || isCustomMhloRngOp(target); } @@ -115,7 +115,8 @@ bool isFusibleWithNoElementwiseFuse(Operation *target, Operation * /*start*/) { bool isValidSingleOp(Operation *op) { return op->hasTrait<::mlir::OpTrait::Elementwise>() || op->hasTrait() || - isa(op) || + isa(op) || isCustomMhloRngOp(op); } diff --git a/compiler/python/byteir/compile.py b/compiler/python/byteir/compile.py index d45100d96..a0c948f90 100644 --- a/compiler/python/byteir/compile.py +++ b/compiler/python/byteir/compile.py @@ -375,7 +375,7 @@ def compile( ### legalize stablehlo to mhlo with context: - PassManager.parse("builtin.module(canonicalize,stablehlo-legalize-to-hlo,canonicalize)").run(module.operation) + PassManager.parse("builtin.module(canonicalize,stablehlo-legalize-to-hlo,canonicalize-ext,canonicalize)").run(module.operation) _print_verbose(module, "// IR Dump After Legalize to HLO:") if verbose else ... ### parse output options from output_file_path