Skip to content

Commit

Permalink
Support op conversion from mhlo::ScatterOp to lmhlo::ScatterOp
Browse files Browse the repository at this point in the history
  • Loading branch information
eedalong authored and Pokemons386 committed Nov 22, 2023
1 parent 8621caf commit f71c531
Showing 1 changed file with 66 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,69 @@ struct HloToLhloReturnOpConverter : public BaseOpConversion<mhlo::ReturnOp> {
}
};

// Legalize mhlo.scatter to lmhlo.scatter.
//
// The op's tensor operands are replaced by their buffer equivalents and one
// output buffer is allocated per result (via convertResults). The
// update_computation region is moved onto the new lmhlo op and its block
// signature is converted from tensor types to memref types, with extra
// trailing block arguments added for the region's results.
struct HloToLhloScatterOpConverter : public BaseOpConversion<mhlo::ScatterOp> {
  using BaseOpConversion<mhlo::ScatterOp>::BaseOpConversion;

  LogicalResult matchAndRewrite(
      mhlo::ScatterOp op, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = op->getLoc();
    // The signature conversion below only handles a single entry block.
    if (!llvm::hasSingleElement(op.getUpdateComputation())) {
      return op->emitOpError()
             << "tensor to buffer conversion expects a single block "
                "in the region containing the operation";
    }
    // bufferArgs = converted operands followed by one buffer per result.
    SmallVector<Value, 4> bufferArgs(adaptor.getOperands());
    if (failed(convertResults(op, bufferArgs, rewriter))) return failure();
    // lmhlo ops have no results; outputs are communicated through buffers.
    auto newOp = rewriter.create<mhlo::HloToLhloOp<mhlo::ScatterOp>>(
        loc, std::nullopt, bufferArgs, op->getAttrs());

    // Copy over the operations inside the region.
    rewriter.inlineRegionBefore(op.getUpdateComputation(),
                                newOp.getUpdateComputation(),
                                newOp.getUpdateComputation().end());

    // Convert the region signature to memref and add extra result arguments
    // (one output memref per region result).
    auto& entryBlock = newOp.getUpdateComputation().front();
    TypeConverter::SignatureConversion sigConversion(
        adaptor.getOperands().size());
    for (auto arg : entryBlock.getArguments()) {
      auto oldType = arg.getType().cast<TensorType>();
      auto newType =
          MemRefType::get(oldType.getShape(), oldType.getElementType());
      sigConversion.addInputs(arg.getArgNumber(), newType);
    }
    auto returnOp = cast<mhlo::ReturnOp>(entryBlock.getTerminator());
    if (auto tupleTy =
            returnOp.getResults().front().getType().dyn_cast<TupleType>()) {
      // The region returns a tuple: forward the tuple's operands directly to
      // the terminator and erase the now-dead tuple op. The tuple op's
      // operands must be read *before* rewriter.eraseOp() is requested;
      // dereferencing an op after asking for its erasure is unsafe outside
      // the dialect-conversion driver's deferred-erasure mode.
      auto* tupleOp = returnOp.getODSOperands(0).front().getDefiningOp();
      returnOp.getOperation()->dropAllReferences();
      returnOp.getOperation()->setOperands(tupleOp->getOperands());
      rewriter.eraseOp(tupleOp);
      for (auto ty : tupleTy) {
        auto tensorTy = ty.cast<TensorType>();
        sigConversion.addInputs(
            MemRefType::get(tensorTy.getShape(), tensorTy.getElementType()));
      }
    } else {
      for (auto result : returnOp.getResults()) {
        auto resultType = result.getType().cast<TensorType>();
        sigConversion.addInputs({MemRefType::get(resultType.getShape(),
                                                 resultType.getElementType())});
      }
    }
    rewriter.applySignatureConversion(&newOp.getUpdateComputation(),
                                      sigConversion);

    // Replace the original op's results with the freshly-allocated output
    // buffers that convertResults appended to bufferArgs.
    rewriter.replaceOp(
        op, ArrayRef<Value>(bufferArgs).slice(adaptor.getOperands().size()));

    return success();
  }
};

// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
// buffers if necessary.
//
Expand Down Expand Up @@ -515,7 +578,7 @@ void populateHloToLhloConversionPattern(
// clang-format off
patterns->add<
HloToLhloCustomCallOpConverter,
// HloToLhloDotGeneralOpConverter,
//HloToLhloDotGeneralOpConverter,
HloToLhloOpConverter<mhlo::AbsOp>,
HloToLhloOpConverter<mhlo::AddOp>,
HloToLhloOpConverter<mhlo::AndOp>,
Expand Down Expand Up @@ -577,7 +640,8 @@ void populateHloToLhloConversionPattern(
HloToLhloOpConverter<mhlo::DynamicUpdateSliceOp>,
HloToLhloReduceLikeOpConverter<mhlo::ReduceOp>,
HloToLhloReduceLikeOpConverter<mhlo::ReduceWindowOp>,
HloToLhloReturnOpConverter
HloToLhloReturnOpConverter,
HloToLhloScatterOpConverter
>(*converter, context);
// clang-format on
}
Expand Down

0 comments on commit f71c531

Please sign in to comment.