diff --git a/include/Dialect/Secret/IR/SecretOps.h b/include/Dialect/Secret/IR/SecretOps.h index 835f063fc..b3a960d7e 100644 --- a/include/Dialect/Secret/IR/SecretOps.h +++ b/include/Dialect/Secret/IR/SecretOps.h @@ -5,6 +5,7 @@ #include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project #include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/include/mlir/Interfaces/ControlFlowInterfaces.h" // from @llvm-project #include "mlir/include/mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project diff --git a/include/Dialect/Secret/IR/SecretOps.td b/include/Dialect/Secret/IR/SecretOps.td index aec5f29f7..1ce2d6b44 100644 --- a/include/Dialect/Secret/IR/SecretOps.td +++ b/include/Dialect/Secret/IR/SecretOps.td @@ -164,6 +164,34 @@ def Secret_GenericOp : Secret_Op<"generic", [ // value. Returns nullptr if the value is not a block argument for this // secret.generic. OpOperand *getOpOperandForBlockArgument(Value value); + + // Clones a generic op and adds new yielded values. Returns the new op and + // the value range corresponding to the new result values of the generic. + // Callers can follow this method with something like the following to + // replace the current generic op with the result of this method. + // + // auto [modifiedGeneric, newResults] = + // genericOp.addNewYieldedValues(newResults, rewriter); + // rewriter.replaceOp( + // genericOp, + // ValueRange(modifiedGeneric.getResults() + // .drop_back(newResults.size()))); + std::pair addNewYieldedValues( + ValueRange newValuesToYield, PatternRewriter &rewriter); + + // Clones the current op with the yielded values in `yieldedValuesToRemove` + // removed. Users can replace the current op with the result of this method + // as follows: + // + // SmallVector remainingResults; + // auto modifiedGeneric = + // op.removeYieldedValues(valuesToRemove, rewriter, remainingResults); + // rewriter.replaceAllUsesWith(remainingResults, modifiedGeneric.getResults()); + // rewriter.eraseOp(op); + GenericOp removeYieldedValues( + ValueRange yieldedValuesToRemove, + PatternRewriter &rewriter, + SmallVector &remainingResults); }]; let hasCanonicalizer = 1; diff --git a/include/Dialect/Secret/IR/SecretPatterns.h b/include/Dialect/Secret/IR/SecretPatterns.h index d58bcfd10..1425d7d55 100644 --- a/include/Dialect/Secret/IR/SecretPatterns.h +++ b/include/Dialect/Secret/IR/SecretPatterns.h @@ -64,6 +64,40 @@ struct RemoveUnusedGenericArgs : public OpRewritePattern { PatternRewriter &rewriter) const override; }; +// Remove unused yields of a secret.generic op +// +// E.g., +// +// %res0, %res1 = secret.generic +// { +// ^bb0(%used: i32, %unused: i32): +// %0 = arith.constant 1 : i32 +// %1 = arith.constant 1 : i32 +// secret.yield %0, %1 : i32, i32 +// } -> (!secret.secret, !secret.secret) +// ... ... +// +// is transformed to +// +// %res0 = secret.generic +// ins(%value_sec : !secret.secret) { +// ^bb0(%used: i32): +// %0 = arith.constant 1 : i32 +// %1 = arith.constant 1 : i32 +// secret.yield %0, : i32 +// } -> (!secret.secret) +// +// The dead code elimination pass then removes any subsequent unused ops inside +// the generic. +struct RemoveUnusedYieldedValues : public OpRewritePattern { + RemoveUnusedYieldedValues(mlir::MLIRContext *context) + : OpRewritePattern(context, /*benefit=*/2) {} + + public: + LogicalResult matchAndRewrite(GenericOp op, + PatternRewriter &rewriter) const override; +}; + // Remove non-secret args of a secret.generic op, since they can be referenced // directly in the enclosing scope. struct RemoveNonSecretGenericArgs : public OpRewritePattern { diff --git a/lib/Dialect/Secret/IR/SecretOps.cpp b/lib/Dialect/Secret/IR/SecretOps.cpp index e1561d630..c11935547 100644 --- a/lib/Dialect/Secret/IR/SecretOps.cpp +++ b/lib/Dialect/Secret/IR/SecretOps.cpp @@ -9,8 +9,10 @@ #include "mlir/include/mlir/IR/Attributes.h" // from @llvm-project #include "mlir/include/mlir/IR/Block.h" // from @llvm-project #include "mlir/include/mlir/IR/Builders.h" // from @llvm-project +#include "mlir/include/mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/include/mlir/IR/OpImplementation.h" // from @llvm-project #include "mlir/include/mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/include/mlir/IR/Region.h" // from @llvm-project #include "mlir/include/mlir/IR/Types.h" // from @llvm-project #include "mlir/include/mlir/IR/Value.h" // from @llvm-project @@ -255,10 +257,71 @@ YieldOp GenericOp::getYieldOp() { return *getBody()->getOps().begin(); } +GenericOp cloneWithNewTypes(GenericOp op, TypeRange newTypes, + PatternRewriter &rewriter) { + return rewriter.create( + op.getLoc(), op.getOperands(), newTypes, + [&](OpBuilder &b, Location loc, ValueRange blockArguments) { + IRMapping mp; + for (BlockArgument blockArg : op.getBody()->getArguments()) { + mp.map(blockArg, blockArguments[blockArg.getArgNumber()]); + } + for (auto &op : op.getBody()->getOperations()) { + b.clone(op, mp); + } + }); +} + +std::pair GenericOp::addNewYieldedValues( + ValueRange newValuesToYield, PatternRewriter &rewriter) { + YieldOp yieldOp = getYieldOp(); + yieldOp.getValuesMutable().append(newValuesToYield); + auto newTypes = llvm::to_vector<4>( + llvm::map_range(yieldOp.getValues().getTypes(), [](Type t) -> Type { + SecretType newTy = secret::SecretType::get(t); + return newTy; + })); + GenericOp newOp = cloneWithNewTypes(*this, newTypes, rewriter); + + auto newResultStartIter = newOp.getResults().drop_front( + newOp.getNumResults() - newValuesToYield.size()); + + return {newOp, ValueRange(newResultStartIter)}; +} + +GenericOp GenericOp::removeYieldedValues(ValueRange yieldedValuesToRemove, + PatternRewriter &rewriter, + SmallVector &remainingResults) { + YieldOp yieldOp = getYieldOp(); + for ([[maybe_unused]] Value v : yieldedValuesToRemove) { + assert(llvm::is_contained(yieldOp.getValues(), v) && + "Cannot remove a value that is not yielded"); + } + + for (int i = 0; i < getYieldOp()->getNumOperands(); ++i) { + Value result = getResults()[i]; + if (result.use_empty()) { + getYieldOp().getValuesMutable().erase(i); + // Ensure the next iteration uses the right arg number + --i; + } else { + remainingResults.push_back(result); + } + } + + auto newResultTypes = llvm::to_vector<4>( + llvm::map_range(yieldOp.getValues().getTypes(), [](Type t) -> Type { + SecretType newTy = secret::SecretType::get(t); + return newTy; + })); + + return cloneWithNewTypes(*this, newResultTypes, rewriter); +} + void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results.add(context); } } // namespace secret diff --git a/lib/Dialect/Secret/IR/SecretPatterns.cpp b/lib/Dialect/Secret/IR/SecretPatterns.cpp index c47426385..0d980bbc0 100644 --- a/lib/Dialect/Secret/IR/SecretPatterns.cpp +++ b/lib/Dialect/Secret/IR/SecretPatterns.cpp @@ -43,6 +43,27 @@ LogicalResult RemoveUnusedGenericArgs::matchAndRewrite( return hasUnusedOps ? success() : failure(); } +LogicalResult RemoveUnusedYieldedValues::matchAndRewrite( + GenericOp op, PatternRewriter &rewriter) const { + SmallVector valuesToRemove; + for (auto &opOperand : op.getYieldOp()->getOpOperands()) { + Value result = op.getResults()[opOperand.getOperandNumber()]; + if (result.use_empty()) { + valuesToRemove.push_back(opOperand.get()); + } + } + + if (!valuesToRemove.empty()) { + SmallVector remainingResults; + auto modifiedGeneric = + op.removeYieldedValues(valuesToRemove, rewriter, remainingResults); + rewriter.replaceAllUsesWith(remainingResults, modifiedGeneric.getResults()); + rewriter.eraseOp(op); + return success(); + } + return failure(); +} + LogicalResult RemoveNonSecretGenericArgs::matchAndRewrite( GenericOp op, PatternRewriter &rewriter) const { bool deletedAny = false; diff --git a/lib/Dialect/Secret/Transforms/DistributeGeneric.cpp b/lib/Dialect/Secret/Transforms/DistributeGeneric.cpp index 5b53ad3c9..4201aab85 100644 --- a/lib/Dialect/Secret/Transforms/DistributeGeneric.cpp +++ b/lib/Dialect/Secret/Transforms/DistributeGeneric.cpp @@ -237,48 +237,6 @@ struct SplitGeneric : public OpRewritePattern { // RegionBranchOpInterface (scf.while, scf.if). } - // Adds new value to yield in a generic op. This requires replacing the entire - // generic op and cloning all its ops, since I can't find a way to modify the - // op's return type in-place. - // - // Returns the new op and the value range corresponding to the new result - // values of the generic. - std::pair addNewYieldedValues( - GenericOp genericOp, ValueRange newValuesToYield, - PatternRewriter &rewriter) const { - YieldOp yieldOp = genericOp.getYieldOp(); - yieldOp.getValuesMutable().append(newValuesToYield); - auto newTypes = llvm::to_vector<4>( - llvm::map_range(yieldOp.getValues().getTypes(), [](Type t) -> Type { - SecretType newTy = secret::SecretType::get(t); - LLVM_DEBUG(llvm::dbgs() << "Adding new type: " << newTy << "\n"); - return newTy; - })); - GenericOp newOp = rewriter.create( - genericOp.getLoc(), genericOp.getOperands(), newTypes, - [&](OpBuilder &b, Location loc, ValueRange blockArguments) { - IRMapping mp; - for (BlockArgument blockArg : genericOp.getBody()->getArguments()) { - mp.map(blockArg, blockArguments[blockArg.getArgNumber()]); - } - for (auto &op : genericOp.getBody()->getOperations()) { - LLVM_DEBUG(llvm::dbgs() << "Cloning " << op.getName() << "\n"); - b.clone(op, mp); - } - }); - - LLVM_DEBUG(newOp.emitRemark() << "Cloned generic Op with new results\n"); - - auto newResultStartIter = newOp.getResults().drop_front( - newOp.getNumResults() - newValuesToYield.size()); - - LLVM_DEBUG(llvm::dbgs() << "Replacing old op\n"); - rewriter.replaceOp( - genericOp, - ValueRange(newOp.getResults().drop_back(newValuesToYield.size()))); - return {newOp, ValueRange(newResultStartIter)}; - } - /// Move an op from the body of one secret.generic to an earlier /// secret.generic in the same block. Updates the yielded values and operands /// of the secret.generics appropriately. @@ -376,7 +334,10 @@ struct SplitGeneric : public OpRewritePattern { Operation *clonedOp = rewriter.clone(opToMove, cloningMp); clonedOp->moveBefore(targetGeneric.getYieldOp()); auto [modifiedGeneric, newResults] = - addNewYieldedValues(targetGeneric, clonedOp->getResults(), rewriter); + targetGeneric.addNewYieldedValues(clonedOp->getResults(), rewriter); + rewriter.replaceOp( + targetGeneric, + ValueRange(modifiedGeneric.getResults().drop_back(newResults.size()))); LLVM_DEBUG(modifiedGeneric.emitRemark() << "Added new yielded values to target generic\n"); diff --git a/tests/secret/canonicalize.mlir b/tests/secret/canonicalize.mlir new file mode 100644 index 000000000..5af109363 --- /dev/null +++ b/tests/secret/canonicalize.mlir @@ -0,0 +1,16 @@ +// RUN: heir-opt --canonicalize %s | FileCheck %s + +// CHECK-LABEL: func @remove_unused_yielded_values +func.func @remove_unused_yielded_values(%arg0: !secret.secret) -> !secret.secret { + %X = arith.constant 7 : i32 + %Y = secret.conceal %X : i32 -> !secret.secret + %Z, %UNUSED = secret.generic + ins(%Y, %arg0 : !secret.secret, !secret.secret) { + ^bb0(%y: i32, %clear_arg0 : i32) : + %d = arith.addi %clear_arg0, %y: i32 + %unused = arith.addi %y, %y: i32 + // CHECK: secret.yield %[[value:.*]] : i32 + secret.yield %d, %unused : i32, i32 + } -> (!secret.secret, !secret.secret) + func.return %Z : !secret.secret +}