diff --git a/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_lhlo/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_lhlo/hlo_legalize_to_lhlo.cc
index 6bca4a6f814..14c93dd3aa6 100644
--- a/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_lhlo/hlo_legalize_to_lhlo.cc
+++ b/tensorflow/compiler/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_lhlo/hlo_legalize_to_lhlo.cc
@@ -348,6 +348,69 @@ struct HloToLhloReturnOpConverter : public BaseOpConversion<mhlo::ReturnOp> {
   }
 };
 
+// Legalize mhlo.scatter to a lmhlo.scatter
+struct HloToLhloScatterOpConverter : public BaseOpConversion<mhlo::ScatterOp> {
+  using BaseOpConversion<mhlo::ScatterOp>::BaseOpConversion;
+
+  LogicalResult matchAndRewrite(
+      mhlo::ScatterOp op, OpAdaptor adaptor,
+      ConversionPatternRewriter& rewriter) const final {
+    auto loc = op->getLoc();
+    // Bufferization of the region body only handles single-block regions.
+    if (!llvm::hasSingleElement(op.getUpdateComputation())) {
+      return op->emitOpError()
+             << "tensor to buffer conversion expects a single block "
+                "in the region containing the operation";
+    }
+    SmallVector<Value, 4> bufferArgs(adaptor.getOperands());
+    if (failed(convertResults(op, bufferArgs, rewriter))) return failure();
+    auto newOp = rewriter.create<lmhlo::ScatterOp>(
+        loc, std::nullopt, bufferArgs, op->getAttrs());
+
+    // Copy over the operations inside the region.
+    rewriter.inlineRegionBefore(op.getUpdateComputation(),
+                                newOp.getUpdateComputation(),
+                                newOp.getUpdateComputation().end());
+
+    // Convert the region signature to memref and add extra result.
+    auto& entryBlock = newOp.getUpdateComputation().front();
+    TypeConverter::SignatureConversion sigConversion(
+        adaptor.getOperands().size());
+    for (auto arg : entryBlock.getArguments()) {
+      auto oldType = arg.getType().cast<TensorType>();
+      auto newType =
+          MemRefType::get(oldType.getShape(), oldType.getElementType());
+      sigConversion.addInputs(arg.getArgNumber(), newType);
+    }
+    auto returnOp = cast<mhlo::ReturnOp>(entryBlock.getTerminator());
+    if (auto tupleTy =
+            returnOp.getResults().front().getType().dyn_cast<TupleType>()) {
+      // Flatten a tuple return into individual operands. Erasure through the
+      // conversion rewriter is deferred, so reading tupleOp's operands after
+      // eraseOp() is safe (same sequence as HloToLhloReduceLikeOpConverter).
+      auto* tupleOp = returnOp.getODSOperands(0).front().getDefiningOp();
+      returnOp.getOperation()->dropAllReferences();
+      rewriter.eraseOp(tupleOp);
+      returnOp.getOperation()->setOperands(tupleOp->getOperands());
+      for (auto ty : tupleTy) {
+        auto tensorTy = ty.cast<TensorType>();
+        sigConversion.addInputs(
+            MemRefType::get(tensorTy.getShape(), tensorTy.getElementType()));
+      }
+    } else {
+      for (auto result : returnOp.getResults()) {
+        auto resultType = result.getType().cast<ShapedType>();
+        sigConversion.addInputs({MemRefType::get(
+            resultType.getShape(), resultType.getElementType())});
+      }
+    }
+    rewriter.applySignatureConversion(&newOp.getUpdateComputation(),
+                                      sigConversion);
+
+    // The results now live in the output buffers appended by convertResults.
+    rewriter.replaceOp(
+        op, ArrayRef<Value>(bufferArgs).slice(adaptor.getOperands().size()));
+
+    return success();
+  }
+};
+
 // Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
 // buffers if necessary.
 //
@@ -577,7 +640,8 @@ void populateHloToLhloConversionPattern(
       HloToLhloOpConverter<mhlo::XorOp>,
       HloToLhloReduceLikeOpConverter<mhlo::ReduceOp>,
       HloToLhloReduceLikeOpConverter<mhlo::ReduceWindowOp>,
-      HloToLhloReturnOpConverter
+      HloToLhloReturnOpConverter,
+      HloToLhloScatterOpConverter
   >(*converter, context);
   // clang-format on
 }