Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
MaheshRavishankar authored Jan 27, 2024
1 parent 46a25d7 commit 28c7051
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 21 deletions.
2 changes: 1 addition & 1 deletion externals/llvm-project
8 changes: 3 additions & 5 deletions lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class AdjustCallingConventionForFunc
}
newResultTypes.push_back(type);
}
rewriter.updateRootInPlace(func, [&] {
rewriter.modifyOpInPlace(func, [&] {
func.setType(FunctionType::get(
getContext(), conversion.getConvertedTypes(), newResultTypes));
// Clear out the type bounds, now that the type incorporates them.
Expand Down Expand Up @@ -194,14 +194,12 @@ static LogicalResult adjustCallingConventions(func::FuncOp func,
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
typeConverter.addConversion(
[](Torch::TupleType type,
SmallVectorImpl<Type> &types) -> LogicalResult {
[](Torch::TupleType type, SmallVectorImpl<Type> &types) -> LogicalResult {
llvm::append_range(types, type.getContainedTypes());
return success();
});
typeConverter.addConversion(
[](Torch::NoneType type,
SmallVectorImpl<Type> &types) -> LogicalResult {
[](Torch::NoneType type, SmallVectorImpl<Type> &types) -> LogicalResult {
return success();
});

Expand Down
4 changes: 2 additions & 2 deletions lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock

// Replace return type of view-like ops with value-semantics type variant.
for (Operation *viewLikeOp : ops.viewLikeOps) {
rewriter.updateRootInPlace(viewLikeOp, [&] {
rewriter.modifyOpInPlace(viewLikeOp, [&] {
Value result = viewLikeOp->getResult(0);
auto resultType = result.getType().dyn_cast<NonValueTensorType>();
if (resultType)
Expand Down Expand Up @@ -337,7 +337,7 @@ class RewriteViewLikeSubgraph
// correctly copy them back to their mlir::func::ReturnOp's expected types.
DenseMap<Value, Type> originalTypes;
for (Operation *op : viewLikeOps) {
rewriter.updateRootInPlace(op, [&]() {
rewriter.modifyOpInPlace(op, [&]() {
if (auto nonValueTensorType =
op->getResult(0).getType().dyn_cast<NonValueTensorType>()) {
originalTypes[op->getResult(0)] = nonValueTensorType;
Expand Down
28 changes: 15 additions & 13 deletions lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@

#include "PassDetail.h"

#include "ReifyAbstractInterpCalculationsUtils.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "ReifyAbstractInterpCalculationsUtils.h"
#include "llvm/ADT/StringExtras.h"

using namespace mlir;
Expand Down Expand Up @@ -72,8 +72,8 @@ namespace {
// immutable tensors.
class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
public:
ConvertHasValueSemanticsOpsToValueTensors(MLIRContext *context,
const std::optional<SymbolTable>& extraLibrary)
ConvertHasValueSemanticsOpsToValueTensors(
MLIRContext *context, const std::optional<SymbolTable> &extraLibrary)
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {
this->extraLibrary = extraLibrary;
}
Expand All @@ -87,7 +87,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
return rewriter.notifyMatchFailure(op, "does not have value semantics");
}

rewriter.startRootUpdate(op);
rewriter.startOpModification(op);
// Convert all operands.
SmallVector<Value> newOperands;
for (OpOperand &opOperand : op->getOpOperands()) {
Expand All @@ -105,7 +105,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
auto listConstruct =
opOperand.get().getDefiningOp<PrimListConstructOp>();
if (!listConstruct) {
rewriter.cancelRootUpdate(op);
rewriter.cancelOpModification(op);
return rewriter.notifyMatchFailure(
op, "unimplemented: list of non vtensor type not constructed "
"from list construct");
Expand All @@ -120,7 +120,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
if (!llvm::all_of(listConstruct.getElements(), [](Value val) {
return val.getType().isa<NonValueTensorType, Torch::NoneType>();
})) {
rewriter.cancelRootUpdate(op);
rewriter.cancelOpModification(op);
return rewriter.notifyMatchFailure(
op, "unimplemented: list containing optional type is not "
"handled.");
Expand All @@ -138,7 +138,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {

Type newListType = getContainerOrTensorTypeWithValueSemantics(listType);
if (!newListType) {
rewriter.cancelRootUpdate(op);
rewriter.cancelOpModification(op);
return rewriter.notifyMatchFailure(
op, "Unable to convert list type to value semantics.");
}
Expand All @@ -154,7 +154,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
// from the non value tensor of the original optional value.
auto derefine = opOperand.get().getDefiningOp<DerefineOp>();
if (!derefine) {
rewriter.cancelRootUpdate(op);
rewriter.cancelOpModification(op);
return rewriter.notifyMatchFailure(
op, "unimplemented: optional of non vtensor type not from "
"derefine");
Expand All @@ -180,9 +180,10 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern {
rewriter.create<CopyToNonValueTensorOp>(op->getLoc(), result);
result.replaceAllUsesExcept(nonValueTensor, nonValueTensor);
}
rewriter.finalizeRootUpdate(op);
rewriter.finalizeOpModification(op);
return success();
}

private:
std::optional<SymbolTable> extraLibrary;
};
Expand Down Expand Up @@ -290,17 +291,18 @@ class ReduceTrailingUnderscoreInplaceVariant : public RewritePattern {
Operation *newOp = rewriter.create(state);
// Note: need to convert result to first input's dtype because mix precision
// compute would result in different behaviors.
// For example:
// a = torch.randn(3, 3).half() # float16
// b = torch.randn(3, 3) # float32
// For example:
// a = torch.randn(3, 3).half() # float16
// b = torch.randn(3, 3) # float32
// a += b # i.e. torch.ops.aten.add_(a, b), result is float16
// c = a + b # i.e. torch.ops.aten.add(a, b), result is float32
Value none = rewriter.create<ConstantNoneOp>(op->getLoc());
Value cstFalse = rewriter.create<ConstantBoolOp>(op->getLoc(), false);
auto aDtype = rewriter.create<PrimDtypeOp>(op->getLoc(), op->getOperand(0));
auto toDtype = rewriter.create<AtenToDtypeOp>(
op->getLoc(), newOp->getResult(0).getType(), newOp->getResult(0),
aDtype, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none);
aDtype, /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
/*memory_format=*/none);
auto tensor = rewriter.create<CopyToValueTensorOp>(op->getLoc(), toDtype);
createOverwriteTensorContents(rewriter, op->getLoc(), tensor,
op->getOperand(0));
Expand Down

0 comments on commit 28c7051

Please sign in to comment.