Skip to content

Commit

Permalink
[compiler] chore
Browse files Browse the repository at this point in the history
  • Loading branch information
YellowHCH committed May 30, 2024
1 parent d4fb8ba commit 770f38c
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 6 deletions.
51 changes: 50 additions & 1 deletion compiler/lib/Conversion/MemrefToByre/MemrefToByre.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,54 @@ class ConvertReshapeLikeOpToByrePattern : public OpConversionPattern<OpTy> {
}
};

class ConvertCastOpToByrePattern : public OpConversionPattern<memref::CastOp> {
public:
ConvertCastOpToByrePattern(MLIRContext *ctx)
: OpConversionPattern<memref::CastOp>(ctx) {}

LogicalResult
matchAndRewrite(memref::CastOp op, memref::CastOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (auto subview = op.getSource().getDefiningOp<memref::SubViewOp>()) {
if (!subview.getSource().getType().getLayout().isIdentity())
return failure();
if (!op.getType().cast<MemRefType>().getLayout().isIdentity())
return failure();

rewriter.replaceOpWithNewOp<byre::AliasOp>(op, op.getType(),
subview.getSource(), 0);
return success();
}

return failure();
}
};

class ConvertCollapseShapeOpToByrePattern
: public OpConversionPattern<memref::CollapseShapeOp> {
public:
ConvertCollapseShapeOpToByrePattern(MLIRContext *ctx)
: OpConversionPattern<memref::CollapseShapeOp>(ctx) {}

LogicalResult
matchAndRewrite(memref::CollapseShapeOp op,
memref::CollapseShapeOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (auto subview = op.getSrc().getDefiningOp<memref::SubViewOp>()) {
if (!subview.getSource().getType().getLayout().isIdentity())
return failure();
if (!op.getType().getLayout().isIdentity())
return failure();

rewriter.replaceOpWithNewOp<byre::AliasOp>(op, op.getType(),
subview.getSource(), 0);
return success();
}

return failure();
}
};

class ConvertViewOpToByrePattern : public OpConversionPattern<memref::ViewOp> {
public:
ConvertViewOpToByrePattern(MLIRContext *ctx)
Expand Down Expand Up @@ -196,7 +244,8 @@ void mlir::populateMemrefToByrePattern(RewritePatternSet &patterns) {
ConvertGetGlobalOpToByrePattern,
ConvertReshapeLikeOpToByrePattern<memref::CollapseShapeOp>,
ConvertReshapeLikeOpToByrePattern<memref::ExpandShapeOp>,
ConvertSubViewOpToByrePattern>(patterns.getContext());
ConvertSubViewOpToByrePattern, ConvertCastOpToByrePattern,
ConvertCollapseShapeOpToByrePattern>(patterns.getContext());
}

std::unique_ptr<OperationPass<func::FuncOp>>
Expand Down
9 changes: 5 additions & 4 deletions compiler/lib/Dialect/mhlo/Transforms/GenericFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ bool isFusibleStart(Operation *op) { return true; }
bool isFusibleTrigger(Operation *op) {
if (op->hasTrait<::mlir::OpTrait::Elementwise>() ||
op->hasTrait<hlo::OpTrait::BroadcastingElementwise>() ||
isa<mhlo::ReshapeOp>(op) || isCustomMhloRngOp(op)) {
isa<mhlo::ReshapeOp, mhlo::TransposeOp>(op) || isCustomMhloRngOp(op)) {
return true;
}

Expand Down Expand Up @@ -101,8 +101,8 @@ bool isFusibleWith(Operation *target, Operation *start) {
return target->hasTrait<::mlir::OpTrait::Elementwise>() ||
target->hasTrait<hlo::OpTrait::BroadcastingElementwise>() ||
isSplatMhloConstantLike(target) ||
isa<mhlo::BroadcastInDimOp, mhlo::BroadcastOp, mhlo::ReshapeOp>(
target) ||
isa<mhlo::BroadcastInDimOp, mhlo::BroadcastOp, mhlo::ReshapeOp,
mhlo::TransposeOp>(target) ||
isCustomMhloRngOp(target);
}

Expand All @@ -115,7 +115,8 @@ bool isFusibleWithNoElementwiseFuse(Operation *target, Operation * /*start*/) {
bool isValidSingleOp(Operation *op) {
return op->hasTrait<::mlir::OpTrait::Elementwise>() ||
op->hasTrait<hlo::OpTrait::BroadcastingElementwise>() ||
isa<mhlo::BroadcastInDimOp, mhlo::BroadcastOp, mhlo::IotaOp>(op) ||
isa<mhlo::BroadcastInDimOp, mhlo::BroadcastOp, mhlo::IotaOp,
mhlo::TransposeOp>(op) ||
isCustomMhloRngOp(op);
}

Expand Down
2 changes: 1 addition & 1 deletion compiler/python/byteir/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ def compile(

### legalize stablehlo to mhlo
with context:
PassManager.parse("builtin.module(canonicalize,stablehlo-legalize-to-hlo,canonicalize)").run(module.operation)
PassManager.parse("builtin.module(canonicalize,stablehlo-legalize-to-hlo,canonicalize-ext,canonicalize)").run(module.operation)
_print_verbose(module, "// IR Dump After Legalize to HLO:") if verbose else ...

### parse output options from output_file_path
Expand Down

0 comments on commit 770f38c

Please sign in to comment.