diff --git a/mlir/include/air/Dialect/AIRRt/AIRRtOps.td b/mlir/include/air/Dialect/AIRRt/AIRRtOps.td index baeb696cb..43cbed364 100644 --- a/mlir/include/air/Dialect/AIRRt/AIRRtOps.td +++ b/mlir/include/air/Dialect/AIRRt/AIRRtOps.td @@ -189,6 +189,7 @@ def AIRRt_AllocOp: AIRRt_Op<"alloc", []> { let description = [{ AIRRt Allocation Op }]; + let hasCanonicalizer = 1; } def AIRRt_DeallocOp: AIRRt_Op<"dealloc", []> { @@ -201,6 +202,7 @@ def AIRRt_DeallocOp: AIRRt_Op<"dealloc", []> { let description = [{ AIRRt Deallocation Op }]; + let hasCanonicalizer = 1; } def AIRRt_WaitAllOp : AIRRt_Op<"wait_all", []> { diff --git a/mlir/lib/Dialect/AIRRt/IR/AIRRtDialect.cpp b/mlir/lib/Dialect/AIRRt/IR/AIRRtDialect.cpp index 055c3171b..24496beb0 100644 --- a/mlir/lib/Dialect/AIRRt/IR/AIRRtDialect.cpp +++ b/mlir/lib/Dialect/AIRRt/IR/AIRRtDialect.cpp @@ -70,7 +70,35 @@ static LogicalResult FoldWaitAll(WaitAllOp op, PatternRewriter &rewriter) { return failure(); } +static LogicalResult FoldAlloc(AllocOp op, PatternRewriter &rewriter) { + auto memref = op.getResult(); + if (!llvm::range_size(memref.getUsers())) { + rewriter.eraseOp(op); + return success(); + } + return failure(); +} + +static LogicalResult FoldDealloc(DeallocOp op, PatternRewriter &rewriter) { + auto memref = op.getOperand(); + if (llvm::range_size(memref.getUsers()) == 1) { + rewriter.eraseOp(op); + return success(); + } + return failure(); +} + void WaitAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(FoldWaitAll); } + +void AllocOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(FoldAlloc); +} + +void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(FoldDealloc); +} diff --git a/mlir/test/Dialect/AIRRt/airrt_canonicalize.mlir b/mlir/test/Dialect/AIRRt/airrt_canonicalize.mlir index 566a747dd..caaa530c3 100644 --- a/mlir/test/Dialect/AIRRt/airrt_canonicalize.mlir +++ b/mlir/test/Dialect/AIRRt/airrt_canonicalize.mlir @@ -32,3 +32,11 @@ func.func @wait_all_1(%e0 : !airrt.event, %e1 : !airrt.event, %e2 : !airrt.event %7 = airrt.wait_all %6 : !airrt.event return %7 : !airrt.event } + +// CHECK-LABEL: alloc_dealloc +// CHECK-NEXT: return +func.func @alloc_dealloc() { + %0 = airrt.alloc : memref<1x4x4x16xi32, 1> + airrt.dealloc %0 : memref<1x4x4x16xi32, 1> + return +}