Skip to content

Commit

Permalink
Merge pull request google#334 from j2kun:remove-unused-yielded-value
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 590927742
  • Loading branch information
copybara-github committed Dec 14, 2023
2 parents 09578fd + 24344e5 commit afbfe51
Show file tree
Hide file tree
Showing 7 changed files with 169 additions and 45 deletions.
1 change: 1 addition & 0 deletions include/Dialect/Secret/IR/SecretOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/Interfaces/ControlFlowInterfaces.h" // from @llvm-project
#include "mlir/include/mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project

Expand Down
28 changes: 28 additions & 0 deletions include/Dialect/Secret/IR/SecretOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,34 @@ def Secret_GenericOp : Secret_Op<"generic", [
// value. Returns nullptr if the value is not a block argument for this
// secret.generic.
OpOperand *getOpOperandForBlockArgument(Value value);

// Clones a generic op and adds new yielded values. Returns the new op and
// the value range corresponding to the new result values of the generic.
// Callers can follow this method with something like the following to
// replace the current generic op with the result of this method.
//
// auto [modifiedGeneric, newResults] =
// genericOp.addNewYieldedValues(newResults, rewriter);
// rewriter.replaceOp(
// genericOp,
// ValueRange(modifiedGeneric.getResults()
// .drop_back(newResults.size())));
std::pair<GenericOp, ValueRange> addNewYieldedValues(
ValueRange newValuesToYield, PatternRewriter &rewriter);

// Clones the current op with the yielded values in `yieldedValuesToRemove`
// removed. Users can replace the current op with the result of this method
// as follows:
//
// SmallVector<Value> remainingResults;
// auto modifiedGeneric =
// op.removeYieldedValues(valuesToRemove, rewriter, remainingResults);
// rewriter.replaceAllUsesWith(remainingResults, modifiedGeneric.getResults());
// rewriter.eraseOp(op);
GenericOp removeYieldedValues(
ValueRange yieldedValuesToRemove,
PatternRewriter &rewriter,
SmallVector<Value> &remainingResults);
}];

let hasCanonicalizer = 1;
Expand Down
34 changes: 34 additions & 0 deletions include/Dialect/Secret/IR/SecretPatterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,40 @@ struct RemoveUnusedGenericArgs : public OpRewritePattern<GenericOp> {
PatternRewriter &rewriter) const override;
};

// Remove unused yields of a secret.generic op
//
// E.g.,
//
// %res0, %res1 = secret.generic
// {
// ^bb0(%used: i32, %unused: i32):
// %0 = arith.constant 1 : i32
// %1 = arith.constant 1 : i32
// secret.yield %0, %1 : i32, i32
// } -> (!secret.secret<i32>, !secret.secret<i32>)
// ... <only use %res0> ...
//
// is transformed to
//
// %res0 = secret.generic
// ins(%value_sec : !secret.secret<i32>) {
// ^bb0(%used: i32):
// %0 = arith.constant 1 : i32
// %1 = arith.constant 1 : i32
// secret.yield %0, : i32
// } -> (!secret.secret<i32>)
//
// The dead code elimination pass then removes any subsequent unused ops inside
// the generic.
struct RemoveUnusedYieldedValues : public OpRewritePattern<GenericOp> {
RemoveUnusedYieldedValues(mlir::MLIRContext *context)
: OpRewritePattern<GenericOp>(context, /*benefit=*/2) {}

public:
LogicalResult matchAndRewrite(GenericOp op,
PatternRewriter &rewriter) const override;
};

// Remove non-secret args of a secret.generic op, since they can be referenced
// directly in the enclosing scope.
struct RemoveNonSecretGenericArgs : public OpRewritePattern<GenericOp> {
Expand Down
67 changes: 65 additions & 2 deletions lib/Dialect/Secret/IR/SecretOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
#include "mlir/include/mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/include/mlir/IR/Block.h" // from @llvm-project
#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project
#include "mlir/include/mlir/IR/IRMapping.h" // from @llvm-project
#include "mlir/include/mlir/IR/OpImplementation.h" // from @llvm-project
#include "mlir/include/mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/IR/Region.h" // from @llvm-project
#include "mlir/include/mlir/IR/Types.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
Expand Down Expand Up @@ -255,10 +257,71 @@ YieldOp GenericOp::getYieldOp() {
return *getBody()->getOps<YieldOp>().begin();
}

GenericOp cloneWithNewTypes(GenericOp op, TypeRange newTypes,
PatternRewriter &rewriter) {
return rewriter.create<GenericOp>(
op.getLoc(), op.getOperands(), newTypes,
[&](OpBuilder &b, Location loc, ValueRange blockArguments) {
IRMapping mp;
for (BlockArgument blockArg : op.getBody()->getArguments()) {
mp.map(blockArg, blockArguments[blockArg.getArgNumber()]);
}
for (auto &op : op.getBody()->getOperations()) {
b.clone(op, mp);
}
});
}

std::pair<GenericOp, ValueRange> GenericOp::addNewYieldedValues(
ValueRange newValuesToYield, PatternRewriter &rewriter) {
YieldOp yieldOp = getYieldOp();
yieldOp.getValuesMutable().append(newValuesToYield);
auto newTypes = llvm::to_vector<4>(
llvm::map_range(yieldOp.getValues().getTypes(), [](Type t) -> Type {
SecretType newTy = secret::SecretType::get(t);
return newTy;
}));
GenericOp newOp = cloneWithNewTypes(*this, newTypes, rewriter);

auto newResultStartIter = newOp.getResults().drop_front(
newOp.getNumResults() - newValuesToYield.size());

return {newOp, ValueRange(newResultStartIter)};
}

GenericOp GenericOp::removeYieldedValues(ValueRange yieldedValuesToRemove,
PatternRewriter &rewriter,
SmallVector<Value> &remainingResults) {
YieldOp yieldOp = getYieldOp();
for ([[maybe_unused]] Value v : yieldedValuesToRemove) {
assert(llvm::is_contained(yieldOp.getValues(), v) &&
"Cannot remove a value that is not yielded");
}

for (int i = 0; i < getYieldOp()->getNumOperands(); ++i) {
Value result = getResults()[i];
if (result.use_empty()) {
getYieldOp().getValuesMutable().erase(i);
// Ensure the next iteration uses the right arg number
--i;
} else {
remainingResults.push_back(result);
}
}

auto newResultTypes = llvm::to_vector<4>(
llvm::map_range(yieldOp.getValues().getTypes(), [](Type t) -> Type {
SecretType newTy = secret::SecretType::get(t);
return newTy;
}));

return cloneWithNewTypes(*this, newResultTypes, rewriter);
}

void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<CollapseSecretlessGeneric, RemoveUnusedGenericArgs,
RemoveNonSecretGenericArgs>(context);
results.add<CollapseSecretlessGeneric, RemoveUnusedYieldedValues,
RemoveUnusedGenericArgs, RemoveNonSecretGenericArgs>(context);
}

} // namespace secret
Expand Down
21 changes: 21 additions & 0 deletions lib/Dialect/Secret/IR/SecretPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,27 @@ LogicalResult RemoveUnusedGenericArgs::matchAndRewrite(
return hasUnusedOps ? success() : failure();
}

LogicalResult RemoveUnusedYieldedValues::matchAndRewrite(
GenericOp op, PatternRewriter &rewriter) const {
SmallVector<Value> valuesToRemove;
for (auto &opOperand : op.getYieldOp()->getOpOperands()) {
Value result = op.getResults()[opOperand.getOperandNumber()];
if (result.use_empty()) {
valuesToRemove.push_back(opOperand.get());
}
}

if (!valuesToRemove.empty()) {
SmallVector<Value> remainingResults;
auto modifiedGeneric =
op.removeYieldedValues(valuesToRemove, rewriter, remainingResults);
rewriter.replaceAllUsesWith(remainingResults, modifiedGeneric.getResults());
rewriter.eraseOp(op);
return success();
}
return failure();
}

LogicalResult RemoveNonSecretGenericArgs::matchAndRewrite(
GenericOp op, PatternRewriter &rewriter) const {
bool deletedAny = false;
Expand Down
47 changes: 4 additions & 43 deletions lib/Dialect/Secret/Transforms/DistributeGeneric.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,48 +278,6 @@ struct SplitGeneric : public OpRewritePattern<GenericOp> {
// RegionBranchOpInterface (scf.while, scf.if).
}

// Adds new value to yield in a generic op. This requires replacing the entire
// generic op and cloning all its ops, since I can't find a way to modify the
// op's return type in-place.
//
// Returns the new op and the value range corresponding to the new result
// values of the generic.
std::pair<GenericOp, ValueRange> addNewYieldedValues(
GenericOp genericOp, ValueRange newValuesToYield,
PatternRewriter &rewriter) const {
YieldOp yieldOp = genericOp.getYieldOp();
yieldOp.getValuesMutable().append(newValuesToYield);
auto newTypes = llvm::to_vector<4>(
llvm::map_range(yieldOp.getValues().getTypes(), [](Type t) -> Type {
SecretType newTy = secret::SecretType::get(t);
LLVM_DEBUG(llvm::dbgs() << "Adding new type: " << newTy << "\n");
return newTy;
}));
GenericOp newOp = rewriter.create<GenericOp>(
genericOp.getLoc(), genericOp.getOperands(), newTypes,
[&](OpBuilder &b, Location loc, ValueRange blockArguments) {
IRMapping mp;
for (BlockArgument blockArg : genericOp.getBody()->getArguments()) {
mp.map(blockArg, blockArguments[blockArg.getArgNumber()]);
}
for (auto &op : genericOp.getBody()->getOperations()) {
LLVM_DEBUG(llvm::dbgs() << "Cloning " << op.getName() << "\n");
b.clone(op, mp);
}
});

LLVM_DEBUG(newOp.emitRemark() << "Cloned generic Op with new results\n");

auto newResultStartIter = newOp.getResults().drop_front(
newOp.getNumResults() - newValuesToYield.size());

LLVM_DEBUG(llvm::dbgs() << "Replacing old op\n");
rewriter.replaceOp(
genericOp,
ValueRange(newOp.getResults().drop_back(newValuesToYield.size())));
return {newOp, ValueRange(newResultStartIter)};
}

/// Move an op from the body of one secret.generic to an earlier
/// secret.generic in the same block. Updates the yielded values and operands
/// of the secret.generics appropriately.
Expand Down Expand Up @@ -417,7 +375,10 @@ struct SplitGeneric : public OpRewritePattern<GenericOp> {
Operation *clonedOp = rewriter.clone(opToMove, cloningMp);
clonedOp->moveBefore(targetGeneric.getYieldOp());
auto [modifiedGeneric, newResults] =
addNewYieldedValues(targetGeneric, clonedOp->getResults(), rewriter);
targetGeneric.addNewYieldedValues(clonedOp->getResults(), rewriter);
rewriter.replaceOp(
targetGeneric,
ValueRange(modifiedGeneric.getResults().drop_back(newResults.size())));
LLVM_DEBUG(modifiedGeneric.emitRemark()
<< "Added new yielded values to target generic\n");

Expand Down
16 changes: 16 additions & 0 deletions tests/secret/canonicalize.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// RUN: heir-opt --canonicalize %s | FileCheck %s

// CHECK-LABEL: func @remove_unused_yielded_values
func.func @remove_unused_yielded_values(%arg0: !secret.secret<i32>) -> !secret.secret<i32> {
%X = arith.constant 7 : i32
%Y = secret.conceal %X : i32 -> !secret.secret<i32>
%Z, %UNUSED = secret.generic
ins(%Y, %arg0 : !secret.secret<i32>, !secret.secret<i32>) {
^bb0(%y: i32, %clear_arg0 : i32) :
%d = arith.addi %clear_arg0, %y: i32
%unused = arith.addi %y, %y: i32
// CHECK: secret.yield %[[value:.*]] : i32
secret.yield %d, %unused : i32, i32
} -> (!secret.secret<i32>, !secret.secret<i32>)
func.return %Z : !secret.secret<i32>
}

0 comments on commit afbfe51

Please sign in to comment.