Skip to content

Commit

Permalink
[BACKEND] Add support for reshape op (#2676)
Browse files Browse the repository at this point in the history
Generalize the view op into a reshape op with an attribute deciding
whether re-ordering elements is allowed.
When re-ordering elements is not allowed, we currently only handle trivial
block layouts and make sure none of the passes generate a different
layout.
  • Loading branch information
ThomasRaoux authored Nov 23, 2023
1 parent fab5bcc commit 607190f
Show file tree
Hide file tree
Showing 20 changed files with 184 additions and 90 deletions.
16 changes: 6 additions & 10 deletions include/triton/Dialect/Triton/IR/TritonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -302,19 +302,15 @@ def TT_ExpandDimsOp : TT_Op<"expand_dims", [Pure,
let hasFolder = 1;
}

// view is not `pure` because it may reorder elements
def TT_ViewOp : TT_Op<"view", [NoMemoryEffect,
SameOperandsAndResultElementType]> {
let summary = "view";

let arguments = (ins TT_Tensor:$src);

def TT_ReshapeOp : TT_Op<"reshape", [Pure,
SameOperandsAndResultElementType]> {
let summary = "reinterpret a tensor to a different shape. It may change elements order if the attribute is set.";
let arguments = (ins TT_Tensor:$src, BoolAttr:$allow_reorder);
let results = (outs TT_Tensor:$result);

let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)";

let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)";
let hasCanonicalizeMethod = 1;
let hasFolder = 1;
let hasVerifier = 1;
}

def TT_BroadcastOp : TT_Op<"broadcast", [Pure,
Expand Down
6 changes: 6 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,12 @@ bool isExpensiveCat(CatOp cat, Attribute targetEncoding);
// Return true if a view between the two types cannot be implemented as a no-op.
bool isExpensiveView(Type srcType, Type dstType);

// Return a blocked encoding where the shape is distributed contiguously amongst
// the threads, warps, CTAs with 1 element per thread.
triton::gpu::BlockedEncodingAttr
getDefaultBlockedEncoding(MLIRContext *context, ArrayRef<int64_t> shape,
int numWarps, int threadsPerWarp, int numCTAs);

} // namespace gpu
} // namespace triton
} // namespace mlir
Expand Down
32 changes: 25 additions & 7 deletions lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,20 +136,38 @@ struct CatOpConversion : public ConvertTritonGPUOpToLLVMPattern<CatOp> {
}
};

struct ViewOpConversion : public ConvertTritonGPUOpToLLVMPattern<ViewOp> {
using OpAdaptor = typename ViewOp::Adaptor;
explicit ViewOpConversion(TritonGPUToLLVMTypeConverter &typeConverter,
struct ReshapeOpConversion : public ConvertTritonGPUOpToLLVMPattern<ReshapeOp> {
using OpAdaptor = typename ReshapeOp::Adaptor;
explicit ReshapeOpConversion(TritonGPUToLLVMTypeConverter &typeConverter,

PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<ViewOp>(typeConverter, benefit) {}
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<ReshapeOp>(typeConverter, benefit) {}

LogicalResult
matchAndRewrite(ViewOp op, OpAdaptor adaptor,
matchAndRewrite(ReshapeOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
assert(!triton::gpu::isExpensiveView(op.getSrc().getType(), op.getType()) &&
"expensive view not supported");
auto resultTy = op.getType().template cast<RankedTensorType>();
auto srcTy = op.getSrc().getType().template cast<RankedTensorType>();
if (!op.getAllowReorder()) {
// Only support trivial block layouts for now.
auto mod = op->getParentOfType<ModuleOp>();
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
int threadsPerWarp =
triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod);
assert(resultTy.getEncoding() == triton::gpu::getDefaultBlockedEncoding(
op.getContext(), resultTy.getShape(),
numWarps, threadsPerWarp, numCTAs) &&
"ReshapeOp lowering only support block encoding right now.");
assert(srcTy.getEncoding() == triton::gpu::getDefaultBlockedEncoding(
op.getContext(), srcTy.getShape(),
numWarps, threadsPerWarp, numCTAs) &&
"ReshapeOp lowering only support block encoding right now.");
}

auto vals = this->getTypeConverter()->unpackLLElements(
loc, adaptor.getSrc(), rewriter, op.getOperand().getType());
Value ret =
Expand Down Expand Up @@ -230,7 +248,7 @@ void populateViewOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter,
ModuleAxisInfoAnalysis &axisInfoAnalysis,
ModuleAllocation &allocation,
PatternBenefit benefit) {
patterns.add<ViewOpConversion>(typeConverter, benefit);
patterns.add<ReshapeOpConversion>(typeConverter, benefit);
patterns.add<ExpandDimsOpConversion>(typeConverter, benefit);
patterns.add<SplatOpConversion>(typeConverter, benefit);
patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
// layout
GenericOpPattern<triton::AdvanceOp>,
GenericOpPattern<triton::MakeTensorPtrOp>,
GenericOpPattern<triton::ViewOp>, GenericOpPattern<triton::BitcastOp>,
GenericOpPattern<triton::ReshapeOp>, GenericOpPattern<triton::BitcastOp>,
GenericOpPattern<triton::FpToFpOp>, GenericOpPattern<triton::IntToPtrOp>,
GenericOpPattern<triton::PtrToIntOp>, GenericOpPattern<triton::SplatOp>,
TritonBroadcastPattern, GenericOpPattern<triton::AddPtrOp>,
Expand Down
26 changes: 20 additions & 6 deletions lib/Dialect/Triton/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,7 @@ OpFoldResult ExpandDimsOp::fold(FoldAdaptor adaptor) {
return foldViewLikeOp(*this, adaptor.getSrc());
}

//-- ViewOp --
//-- ReshapeOp --
template <typename OpType>
LogicalResult canonicalizeViewOrBroadcast(OpType op,
PatternRewriter &rewriter) {
Expand All @@ -744,9 +744,10 @@ LogicalResult canonicalizeViewOrBroadcast(OpType op,
}

// view(view) -> view
if (auto parent_view = dyn_cast<OpType>(definingOp)) {
rewriter.replaceOpWithNewOp<OpType>(op, op.getType(),
parent_view.getOperand());
if (auto parentView = dyn_cast<OpType>(definingOp)) {
rewriter.replaceOpWithNewOp<OpType>(op, TypeRange({op.getType()}),
parentView->getOperands(),
parentView->getAttrs());
return mlir::success();
}

Expand All @@ -759,11 +760,14 @@ LogicalResult canonicalizeViewOrBroadcast(OpType op,

return mlir::failure();
}
LogicalResult ViewOp::canonicalize(ViewOp op, PatternRewriter &rewriter) {

// Canonicalization entry point for ReshapeOp.
// The shared view/broadcast canonicalization is only attempted when the op
// allows its elements to be reordered; otherwise failure is reported and the
// op is left untouched.
LogicalResult ReshapeOp::canonicalize(ReshapeOp op, PatternRewriter &rewriter) {
  return op.getAllowReorder() ? canonicalizeViewOrBroadcast(op, rewriter)
                              : failure();
}

OpFoldResult ViewOp::fold(FoldAdaptor adaptor) {
OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
if (getType() == getOperand().getType()) {
// no-op
return getOperand();
Expand All @@ -772,6 +776,16 @@ OpFoldResult ViewOp::fold(FoldAdaptor adaptor) {
return foldViewLikeOp(*this, adaptor.getSrc());
}

// Verifier for ReshapeOp: the source and result tensors must hold the same
// total number of elements. Emits an error on the op and fails otherwise.
mlir::LogicalResult mlir::triton::ReshapeOp::verify() {
  const auto resultElems = getType().cast<RankedTensorType>().getNumElements();
  const auto sourceElems =
      getSrc().getType().cast<RankedTensorType>().getNumElements();
  if (resultElems == sourceElems)
    return mlir::success();
  return emitError(
      "number of src and dst elements of reshape must be the same");
}

//-- BroadcastOp --
LogicalResult BroadcastOp::canonicalize(BroadcastOp op,
PatternRewriter &rewriter) {
Expand Down
45 changes: 34 additions & 11 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,23 @@ BlockedEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}

// Build the default blocked encoding for `shape`:
//   - sizePerThread is 1 along every dimension
//   - order = reverse(arange(rank)), i.e. rank-1, ..., 1, 0
// numWarps, threadsPerWarp and numCTAs are forwarded unchanged to
// BlockedEncodingAttr::get.
triton::gpu::BlockedEncodingAttr
getDefaultBlockedEncoding(MLIRContext *context, ArrayRef<int64_t> shape,
                          int numWarps, int threadsPerWarp, int numCTAs) {
  int numDims = shape.size();
  llvm::SmallVector<unsigned> dimOrder(numDims);
  // Filling through reverse iterators writes 0 into the last slot and
  // rank-1 into the first, producing the reversed arange directly.
  std::iota(dimOrder.rbegin(), dimOrder.rend(), 0);
  llvm::SmallVector<unsigned> onePerThread(numDims, 1);
  return triton::gpu::BlockedEncodingAttr::get(context, shape, onePerThread,
                                               dimOrder, numWarps,
                                               threadsPerWarp, numCTAs);
}

} // namespace gpu
} // namespace triton
} // namespace mlir
Expand Down Expand Up @@ -1741,13 +1758,14 @@ struct TritonGPUInferLayoutInterface
//===----------------------------------------------------------------------===//

struct CanonicalizeConvertFromView
: public mlir::OpRewritePattern<triton::ViewOp> {
: public mlir::OpRewritePattern<triton::ReshapeOp> {

CanonicalizeConvertFromView(MLIRContext *context)
: OpRewritePattern<triton::ViewOp>(context, 1) {}
: OpRewritePattern<triton::ReshapeOp>(context, 1) {}

mlir::LogicalResult
matchAndRewrite(triton::ViewOp op, PatternRewriter &rewriter) const override {
matchAndRewrite(triton::ReshapeOp op,
PatternRewriter &rewriter) const override {
Operation *arg = op->getOperand(0).getDefiningOp();
if (!arg)
return mlir::failure();
Expand All @@ -1756,9 +1774,12 @@ struct CanonicalizeConvertFromView
return failure();
if (isExpensiveView(convert.getOperand().getType(), op.getType()))
return failure();
// view(convert) -> view
rewriter.replaceOpWithNewOp<triton::ViewOp>(op, op->getResult(0).getType(),
convert.getOperand());
if (!op.getAllowReorder())
return failure();
// reshape(cvt)->reshape
rewriter.replaceOpWithNewOp<triton::ReshapeOp>(
op, op->getResult(0).getType(), convert.getOperand(),
op.getAllowReorder());
return mlir::success();
}
};
Expand Down Expand Up @@ -1802,9 +1823,10 @@ struct CanonicalizeConvertFromConvert
// block argument
if (!arg)
return mlir::failure();
// cvt(view) -> view
if (auto view = dyn_cast<triton::ViewOp>(arg)) {
if (isExpensiveView(view.getOperand().getType(), op.getType()))
// cvt(reshape) -> reshape
if (auto reshape = dyn_cast<triton::ReshapeOp>(arg)) {
if (!reshape.getAllowReorder() ||
isExpensiveView(reshape.getOperand().getType(), op.getType()))
return failure();
// In TritonGPUToLLVM phase, ReshapeOp is converted to unpacking and packing
// operations, which requires the element type to match between unpacking
Expand All @@ -1815,8 +1837,9 @@ struct CanonicalizeConvertFromConvert
if (hasDotOperandEncoding(op->getOperand(0)) ||
hasDotOperandEncoding(op->getResult(0)))
return failure();
rewriter.replaceOpWithNewOp<triton::ViewOp>(
op, op->getResult(0).getType(), view.getResult());
rewriter.replaceOpWithNewOp<triton::ReshapeOp>(
op, op->getResult(0).getType(), reshape.getResult(),
reshape.getAllowReorder());
return mlir::success();
}
// cvt(cat) -> cat
Expand Down
4 changes: 2 additions & 2 deletions lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,8 @@ class TritonGPUOptimizeThreadLocalityPass
builder.setInsertionPointAfter(reduce);
IRMapping mapping;
for (auto operand : reduce.getOperands()) {
auto viewOp = builder.create<triton::ViewOp>(reduce.getLoc(),
viewOpTensorType, operand);
auto viewOp = builder.create<triton::ReshapeOp>(
reduce.getLoc(), viewOpTensorType, operand, /*allowReorder=*/true);
mapping.map(operand, viewOp);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ static bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
static bool isLayoutAnchor(Operation *op) {
if (isa<triton::LoadOp, triton::StoreOp>(op))
return isExpensiveLoadOrStore(op);
if (isa<triton::ViewOp, triton::DotOp, triton::AtomicRMWOp,
if (isa<triton::ReshapeOp, triton::DotOp, triton::AtomicRMWOp,
triton::AtomicCASOp>(op))
return true;
return false;
Expand Down
14 changes: 4 additions & 10 deletions lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "mlir/IR/IRMapping.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include <algorithm>
#include <numeric>

Expand All @@ -24,17 +25,10 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
// TODO: check for layout encodings more specifically
if (tensorType.getEncoding())
return tensorType;
// pessimistic values for attributes:
// - 1 element per thread
// - order = arange(rank)
ArrayRef<int64_t> shape = tensorType.getShape();
int rank = shape.size();
llvm::SmallVector<unsigned> order(rank);
std::iota(order.begin(), order.end(), 0);
llvm::SmallVector<unsigned> sizePerThread(rank, 1);
Attribute encoding = triton::gpu::BlockedEncodingAttr::get(
this->context, shape, sizePerThread, order, this->numWarps,
this->threadsPerWarp, this->numCTAs);
triton::gpu::BlockedEncodingAttr encoding =
getDefaultBlockedEncoding(this->context, shape, this->numWarps,
this->threadsPerWarp, this->numCTAs);
return RankedTensorType::get(shape, tensorType.getElementType(), encoding);
});

Expand Down
17 changes: 10 additions & 7 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ std::optional<Attribute> inferSrcEncoding(Operation *op, Attribute encoding) {
return inferSrcEncoding(reduceOp, encoding);
if (auto expand = dyn_cast<triton::ExpandDimsOp>(op))
return inferSrcEncoding(expand, encoding);
if (isa<triton::ViewOp, triton::CatOp>(op))
if (isa<triton::ReshapeOp, triton::CatOp>(op))
return std::nullopt;
return encoding;
}
Expand All @@ -289,7 +289,7 @@ std::optional<Attribute> inferDstEncoding(Operation *op, Attribute encoding) {
return inferDstEncoding(reduceOp, encoding);
if (auto expand = dyn_cast<triton::ExpandDimsOp>(op))
return inferDstEncoding(expand, encoding);
if (isa<triton::ViewOp, triton::CatOp>(op))
if (isa<triton::ReshapeOp, triton::CatOp>(op))
return std::nullopt;
return encoding;
}
Expand Down Expand Up @@ -344,11 +344,14 @@ bool canFoldIntoConversion(Operation *op, Attribute targetEncoding) {
}
return true;
}
if (auto view = dyn_cast<triton::ViewOp>(op)) {
auto viewDstType = view.getType().cast<RankedTensorType>();
RankedTensorType newDstType = RankedTensorType::get(
viewDstType.getShape(), viewDstType.getElementType(), targetEncoding);
return !triton::gpu::isExpensiveView(view.getOperand().getType(),

if (auto reshape = dyn_cast<triton::ReshapeOp>(op)) {
auto reshapeDstType = reshape.getType().cast<RankedTensorType>();
RankedTensorType newDstType =
RankedTensorType::get(reshapeDstType.getShape(),
reshapeDstType.getElementType(), targetEncoding);
return reshape.getAllowReorder() &&
!triton::gpu::isExpensiveView(reshape.getOperand().getType(),
newDstType);
}
return isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
Expand Down
14 changes: 7 additions & 7 deletions python/src/triton.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1359,14 +1359,14 @@ void init_triton_ir(py::module &&m) {
self.create<mlir::triton::StoreOp>(ptrs, val, mask, cacheModifier,
evictionPolicy);
})
.def("create_view",
.def("create_reshape",
[](TritonOpBuilder &self, mlir::Value &arg,
std::vector<int64_t> &shape) -> mlir::Value {
auto argType = arg.getType()
.dyn_cast<mlir::RankedTensorType>()
.getElementType();
return self.create<mlir::triton::ViewOp>(
mlir::RankedTensorType::get(shape, argType), arg);
std::vector<int64_t> &shape, bool allowReorder) -> mlir::Value {
auto argType =
arg.getType().cast<mlir::RankedTensorType>().getElementType();
return self.create<mlir::triton::ReshapeOp>(
mlir::RankedTensorType::get(shape, argType), arg,
allowReorder);
})
.def(
"create_expand_dims",
Expand Down
29 changes: 29 additions & 0 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3060,6 +3060,35 @@ def kernel(X, s):
np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256) % 8)


reshape_list = [((64, ), (8, 8)), ((2, 32), (16, 4)), ((512, ), (2, 2, 2, 2, 2, 2, 2, 2, 2)), ((64, 32), (16, 8, 16))]


@pytest.mark.parametrize("formats", reshape_list)
def test_reshape(formats, device):
    """Check that `tl.reshape` agrees with numpy's `reshape`.

    `formats` is one `(input_shape, output_shape)` pair from `reshape_list`.
    A random int32 array of the input shape is loaded by the kernel, reshaped
    to the output shape with `tl.reshape`, stored back, and compared for exact
    equality against `x.reshape(out_format)` computed on the host.
    """
    in_format, out_format = formats

    # X_PTR_EXPR and Z_PTR_EXPR are placeholders, not real names: they are
    # substituted through `patch_kernel` (see `generate_kernel` below) before
    # the kernel is launched.
    @triton.jit
    def kernel(Z, X, out_tuple: tl.constexpr):
        x = tl.load(X_PTR_EXPR)
        z = tl.reshape(x, out_tuple)
        tl.store(Z_PTR_EXPR, z)

    # Return `kernel` patched so that each placeholder is replaced by the
    # string `make_ptr_str` produces for the matching argument and shape.
    # NOTE(review): presumably these are pointer expressions covering the whole
    # tensor; confirm against make_ptr_str.
    def generate_kernel(shape_x, shape_z):
        to_replace = {
            'X_PTR_EXPR': make_ptr_str('X', shape_x),
            'Z_PTR_EXPR': make_ptr_str('Z', shape_z),
        }
        return patch_kernel(kernel, to_replace)

    x = numpy_random(in_format, dtype_str="int32")
    # Reference result computed on the host with numpy.
    z = x.reshape(out_format)
    x_tri = to_triton(x, device=device)
    patched_kernel = generate_kernel(in_format, out_format)
    # Output buffer with the target shape; its contents are overwritten by the kernel.
    z_tri = to_triton(np.empty(out_format, dtype=np.int32), device=device)
    # Launch the patched kernel with grid (1, ).
    patched_kernel[(1, )](z_tri, x_tri, out_format)
    np.testing.assert_equal(z, to_numpy(z_tri))


# -------------
# test call
# -------------
Expand Down
Loading

0 comments on commit 607190f

Please sign in to comment.