[Backend] Implement scaled_dot(mxfp4, fp8) #4904

Merged (7 commits) on Oct 16, 2024
@@ -250,13 +250,6 @@ chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
ArrayRef<unsigned> repShape,
ArrayRef<unsigned> paddedRepShape,
ArrayRef<unsigned> order, int swizzleByteSize);

// FIXME
// Exposing to use it in LinearLayoutConversionsTest.cpp
// Remove it once we fully activate the DotOperand conversion via LLs
class DotOperandEncodingAttr;
LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
DotOperandEncodingAttr dot);
} // namespace mlir::triton::gpu

#endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H
7 changes: 6 additions & 1 deletion lib/Analysis/Allocation.cpp
@@ -115,7 +115,12 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,

assert(!isMfmaToDotShortcut(srcTy, dstTy));

auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout);
// FIXME This is NOT entirely correct
// This should be getElemOrder, but we don't have such a method
// TODO Implement getElemOrder and make sure it's consistent with
// getContigPerThread
auto inOrd = gpu::getThreadOrder(srcLayout);
Contributor

I think we assume getElemOrder == getOrder

Contributor

getThreadOrder is the same as getOrder except for AMD's AMDMfmaEncodingAttr; I haven't done a deep investigation.
Ping @zhanglx13 for expertise, maybe.

Contributor Author

See that I changed the definition of getThreadOrder in this PR.

Contributor

To be specific I was referring to:

SmallVector<unsigned> AMDMfmaEncodingAttr::getThreadOrder() const {
  auto order = ::getOrder(*this);
  if (getIsTransposed())
    std::swap(order[0], order[1]);
  return order;
}

I'm not sure if we should use getOrder or getThreadOrder for this encoding
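A minimal illustration of the difference, based on the snippet above (assuming ::getOrder returns [1, 0] for a rank-2 MFMA layout):

  getIsTransposed() == false  ->  getThreadOrder() == [1, 0]  (same as getOrder)
  getIsTransposed() == true   ->  getThreadOrder() == [0, 1]  (dims swapped)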

auto outOrd = gpu::getThreadOrder(dstLayout);
scratchConfig.order = outOrd;

unsigned srcContigPerThread =
32 changes: 32 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -404,6 +404,22 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
}
return true;
}
if (auto dotOperand = dyn_cast<DotOperandEncodingAttr>(layout)) {
if (auto nvidiaMma =
dyn_cast<NvidiaMmaEncodingAttr>(dotOperand.getParent())) {
if (product(getCTAsPerCGA(nvidiaMma)) > 1) {
return false;
}
if (useLegacyMMAConversion) {
return false;
}
// FIXME [Dot LL]
// Enabling LL path for buggy kWidth path
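// Illustration (not part of the diff): kWidth == 8 with a 16-bit element
// type such as bf16 gives 8 * 16 = 128 > 64, so this conversion is routed
// through the linear-layout path on Ampere.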
bool largeKWidth =
dotOperand.getKWidth() * dstTy.getElementTypeBitWidth() > 64;
return largeKWidth && nvidiaMma.isAmpere();
}
}
if (isa<BlockedEncodingAttr>(layout)) {
return true;
}
@@ -460,6 +476,22 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
}
}

// FIXME [Dot LL]
// We know it's just for largeKWidth case in Ampere
// In this case, we need to pack the outputs into i32
if (isa<DotOperandEncodingAttr>(dstTy.getEncoding())) {
auto concat = [&](Value a, Value b) {
return or_(zext(i32_ty, bitcast(a, i16_ty)),
shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16)));
};
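// Illustration (not part of the diff): concat(a, b) with bf16 bit patterns
// a = 0xAAAA and b = 0xBBBB yields the i32 0xBBBBAAAA, i.e. a occupies bits
// [0, 16) and b occupies bits [16, 32).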

SmallVector<Value> outVals32(outVals.size() / 2);
for (int i = 0; i < outVals32.size(); ++i) {
outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]);
}
outVals = outVals32;
}

Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
op.getType());
rewriter.replaceOp(op, result);
@@ -90,6 +90,16 @@ void decomposeBlockedToDotLayoutConversion(ModuleOp module) {
auto dstDotOp =
dyn_cast<triton::gpu::DotOperandEncodingAttr>(dstType.getEncoding());
if (srcBlocked && dstDotOp) {
// FIXME [Dot LL]
// We support this one via LLs, as the LocalLoad path is buggy
if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dstDotOp.getParent())) {
bool largeKWidth =
dstDotOp.getKWidth() * dstType.getElementTypeBitWidth() > 64;
if (mma.isAmpere() && largeKWidth) {
return;
}
}

Attribute sharedMemorySpace =
triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext());
auto tmpType = MemDescType::get(
68 changes: 31 additions & 37 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -8,6 +8,7 @@
#include "mlir/Support/LLVM.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
@@ -234,8 +235,31 @@ static SmallVector<unsigned> eraseOrder(ArrayRef<unsigned> order,
return resOrder;
}

SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
bool kMajor) {
// kMajor: if true, the matrix is fastest-running on k,
// otherwise it is on m (resp. n)
// opIdx=0: [batch, m, k] if rank == 3 else [m, k]
// opIdx=1: [batch, k, n] if rank == 3 else [k, n]
// batch (if rank == 3) is always the slowest running dimension
assert(rank == 2 || rank == 3);
assert(opIdx == 0 || opIdx == 1);
SmallVector<unsigned> order(rank);
std::iota(order.rbegin(), order.rend(), 0);
// If opIdx is 1 and kMajor is true, the order is [0, 1]
// (resp. [1, 2, 0] if rank == 3)
// Same if opIdx is 0 and kMajor is false
if (bool(opIdx) == kMajor) {
std::swap(order[0], order[1]);
}
return order;
}
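// Illustration (not part of the diff) of the orders produced above:
//   (opIdx=0, rank=2, kMajor=true)  -> [1, 0]     // k is fastest
//   (opIdx=1, rank=2, kMajor=true)  -> [0, 1]     // k is fastest
//   (opIdx=0, rank=3, kMajor=true)  -> [2, 1, 0]
//   (opIdx=1, rank=3, kMajor=true)  -> [1, 2, 0]
//   With kMajor=false the swap applies to opIdx=0 instead, e.g.
//   (opIdx=0, rank=2, kMajor=false) -> [0, 1]     // m is fastest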

SmallVector<unsigned> getWarpOrder(Attribute layout) {
auto order = getOrder(layout);
// FIXME: This mmaLayout if should just return
// getOrderForDotOperand(0, order.size(), kMajor=false)
// as mma has the same order as DotOperand(opIdx=0)
if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
if (mmaLayout.isHopper()) {
// Hopper MMA instructions force a warp order of [0, 1]. See docs:
@@ -245,40 +269,8 @@ SmallVector<unsigned> getWarpOrder(Attribute layout) {
order.insert(order.begin(), 0);
}
} else if (auto dotOpLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
// opIdx=0: [/*dim0*/batch, /*dim1=*/m, /*dim2=*/k] -> order=[1, 2, 0]
// opIdx=1: [/*dim0*/batch, /*dim1=*/k, /*dim2=*/n] -> order=[2, 1, 0]
std::iota(order.rbegin(), order.rend(), 0);
if (dotOpLayout.getOpIdx() == 0) {
std::swap(order[0], order[1]);
}
}
return order;
}

SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank) {
assert((rank == 2 || rank == 3) &&
"Invalid rank for dot operand order computation");
SmallVector<unsigned> order(rank);
// The 'order' field typically represents a descending sorted array of
// dimensions based on contiguity. For instance, in axisInfo utilities that
// retrieve tensor contiguity, it's assumed that the dimension with the
// highest contiguity corresponds to order[0].
//
// The relation between contiguity and order is only relevant if the layout
// interfaces with HBM, as is the case when we load tensor from HBM to
// registers in the dot layout to bypass LDS. When bypassing LDS, we make
// the following assumptions about tensor layouts:
// - Tensor A (opIdx == 0) is considered to be row-major.
// - Tensor B (opIdx == 1) is considered to be column-major.
//
// Based on these assumptions, we define the following orders:
// - For opIdx == 0, batch=dim0, m=dim1, and k=dim2, we assume an order of [2,
// 1, 0] for 3D tensors.
// - For opIdx == 1, batch=dim0, k=dim1, and n=dim2, we assume an order of [1,
// 2, 0] for 3D tensors.
std::iota(order.rbegin(), order.rend(), 0);
if (opIdx == 1) {
std::swap(order[0], order[1]);
order = getOrderForDotOperand(dotOpLayout.getOpIdx(), order.size(),
/*kMajor*/ false);
Comment on lines +272 to +273
Collaborator

why is kMajor always false here?

Contributor

This is getting the warp order, not the element order, so m is the fastest-changing dimension for opIdx=0. I think the confusion may arise from the variable name kMajor.

Contributor

I don't have a suggestion for improvement though. Maybe just add some additional comments.

Contributor Author

Yep, similarly to wgmma, we want the warps to have the exterior dimension (i.e. not K) as their fastest-running dimension.

}
return order;
}
Expand All @@ -295,8 +287,8 @@ SmallVector<unsigned> getOrder(Attribute layout) {
return order;
}
if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
auto rank = getWarpsPerCTA(dotLayout.getParent()).size();
return getOrderForDotOperand(dotLayout.getOpIdx(), rank);
auto rank = dotLayout.getWarpsPerCTA().size();
return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kMajor*/ true);
}
if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {
SmallVector<unsigned> parentOrder = getOrder(sliceLayout.getParent());
@@ -1048,7 +1040,8 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
return ::getWarpOrder(*this);
}
SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
return ::getOrder(*this);
return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
/*kMajor*/ true);
}
SmallVector<unsigned> DotOperandEncodingAttr::getShapePerCTATile(
ArrayRef<int64_t> tensorShape) const {
@@ -2019,6 +2012,7 @@ SmallVector<int64_t> NvidiaMmaEncodingAttr::getMMAv2RepForOperand(
ArrayRef<int64_t> shape, int bitwidth, int kWidth, int opIdx) const {
auto rank = shape.size();
auto warpsPerCTA = getWarpsPerCTA();

SmallVector<int> shapePerWarp = {1, 16, 8, 4 * 64 / bitwidth};
int numRepBatch =
rank == 3
13 changes: 7 additions & 6 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -886,13 +886,14 @@ std::optional<LinearLayout>
DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(getParent())) {
return dotOperandMfmaToLinearLayout(*this, shape);
} else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(getParent())) {
// FIXME [Dot LL]
// Do this unconditionally
auto largeKWidth = getKWidth() == 8;
if (mma.isAmpere() && largeKWidth) {
return ampereDotToLinearLayout(shape, *this);
}
}
// TODO Activate in a follow-up PR
// else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(getParent())) {
// if (mma.isAmpere()) {
// return ampereDotToLinearLayout(shape, *this);
// }
//}
return std::nullopt;
}

31 changes: 18 additions & 13 deletions lib/Dialect/TritonGPU/IR/Ops.cpp
@@ -17,8 +17,9 @@ LogicalResult UpcastMXFPOp::verify() {
auto xTy = getSrc().getType();
auto scaleTy = getScale().getType();

if (xTy.getElementType() != FloatType::getBF16(getContext())) {
return emitOpError("element type of the first operand must be bf16");
if (xTy.getElementType() != FloatType::getBF16(getContext()) &&
xTy.getElementType() != IntegerType::get(getContext(), 8)) {
return emitOpError("element type of the first operand must be bf16 or i8");
}

if (scaleTy.getElementType() != IntegerType::get(getContext(), 8)) {
@@ -72,7 +73,7 @@ LogicalResult UpcastMXFPOp::verify() {
}

LogicalResult UpcastMXFPOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
MLIRContext *ctx, std::optional<Location> loc, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties opaqueProperties,
RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
auto xTy = cast<RankedTensorType>(operands[0].getType());
@@ -82,21 +83,25 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(

auto encoding = xTy.getEncoding();
if (!encoding) {
return emitOptionalError(location, "expected an encoding");
return emitOptionalError(loc, "expected an encoding");
}
if (!mlir::isa<DotOperandEncodingAttr>(encoding)) {
return emitOptionalError(location, "expected an mma layout encoding");
}
if (xShape.size() < 2) {
return emitOptionalError(location, "tensor rank must be at least 2");
return emitOptionalError(loc, "expected a dotOperand encoding");
}

// For now we just return the input encoding. For fp4 we'll need to cast from
// tf32 to fp16 encoding and multiply the shape by two
assert((typeEncoded == F8F6F4Type::E4M3 || typeEncoded == F8F6F4Type::E5M2) &&
"NYI: only fp8e4m3 and fp8e5m2 are supported");
if (typeEncoded == F8F6F4Type::E2M1) {
auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
auto newVEncoding = DotOperandEncodingAttr::get(
ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(),
oldEncoding.getKWidth() * 2);
auto newShape = SmallVector<int64_t>(xShape);
newShape.back() *= 2;
inferredReturnTypes.push_back(
RankedTensorType::get(newShape, FloatType::getBF16(ctx), newVEncoding));
} else {
inferredReturnTypes.push_back(xTy);
}
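// Illustration (assumed shapes, not part of the diff): an E2M1 operand of
// type 128x64xi8 (two fp4 values packed per byte) with kWidth = 4 infers a
// result type of 128x128xbf16 with kWidth = 8, while fp8 operands keep their
// input type unchanged.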

inferredReturnTypes.push_back(xTy);
return success();
}

62 changes: 38 additions & 24 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -406,7 +406,7 @@ class ScaledBlockedToMMAv2
auto ctx = dotOp.getContext();

// Check that rhs scale is null
assert(dotOp.getRhsScale() == nullptr && "rhs scale must be null");
assert(dotOp.getRhsScale() == nullptr && "rhs scale NYI");

// operands
auto a = dotOp.getLhs();
@@ -426,10 +426,11 @@
}
};

assert(aType == F8F6F4Type::E4M3 ||
aType == F8F6F4Type::E5M2 && "lhs just supports fp8");
assert((aType == F8F6F4Type::E4M3 || aType == F8F6F4Type::E5M2 ||
aType == F8F6F4Type::E2M1) &&
"NYI: lhs supports fp4 or fp8");
assert(bType == F8F6F4Type::E4M3 ||
bType == F8F6F4Type::E5M2 && "rhs just supports fp8");
bType == F8F6F4Type::E5M2 && "NYI: rhs supports fp8");

// TODO run accelerate matmul on A and B first to choose their layouts
// Set return type
@@ -440,6 +441,7 @@ class ScaledBlockedToMMAv2
auto instrShape = mmaVersionToInstrShape(versionMajor, retShapePerCTA,
rewriter.getBF16Type(), numWarps);
auto CTALayout = getCTALayout(oldRetType.getEncoding());
// TODO Use warpsPerTileV2
SmallVector<unsigned> warpsPerCTA = {numWarps, 1};
auto mmaEnc = NvidiaMmaEncodingAttr::get(ctx, /*versionMajor=*/versionMajor,
/*versionMinor=*/0, warpsPerCTA,
@@ -452,27 +454,39 @@ class ScaledBlockedToMMAv2
auto newAcc =
rewriter.create<ConvertLayoutOp>(oldAcc.getLoc(), newRetType, oldAcc);

auto toMMABf16 = [&newRetType, &rewriter, &ctx,
&enumToType](TypedValue<RankedTensorType> v, int idx,
F8F6F4Type type) {
// MMAv2 Layout
auto toMMABf16 = [&newRetType, &rewriter, &ctx, &enumToType](
TypedValue<RankedTensorType> v, int idx,
F8F6F4Type type) -> TypedValue<RankedTensorType> {
auto vType = v.getType();
auto newVEncoding = DotOperandEncodingAttr::get(
ctx, idx, newRetType.getEncoding(), enumToType((type)));
auto newVType = RankedTensorType::get(
v.getType().getShape(), v.getType().getElementType(), newVEncoding);
v = rewriter.create<ConvertLayoutOp>(v.getLoc(), newVType, v);

// Bitcast
auto vTypeFp8 = RankedTensorType::get(
vType.getShape(), rewriter.getFloat8E4M3FNType(), newVEncoding);
v = cast<TypedValue<RankedTensorType>>(
rewriter.create<BitcastOp>(v.getLoc(), vTypeFp8, v).getResult());

// Convert to bf16
auto vTypeBf16 = RankedTensorType::get(
vType.getShape(), rewriter.getBF16Type(), newVEncoding);
return rewriter.create<FpToFpOp>(v.getLoc(), vTypeBf16, v);
if (type == F8F6F4Type::E2M1) {
// A bit too dynamically typed...
// perhaps return ints in both cases?

auto retEnc = dyn_cast<NvidiaMmaEncodingAttr>(newRetType.getEncoding());
auto newVEncoding = DotOperandEncodingAttr::get(
ctx, idx, newRetType.getEncoding(), /*kWidth=*/4);
auto newVType = RankedTensorType::get(
vType.getShape(), vType.getElementType(), newVEncoding);
return rewriter.create<ConvertLayoutOp>(v.getLoc(), newVType, v);
} else {
assert(type == F8F6F4Type::E5M2 || type == F8F6F4Type::E4M3);
auto newVEncoding = DotOperandEncodingAttr::get(
Collaborator

nit: assert that this is an fp8 type?

Contributor Author

Done, although it's a bit redundant, as we are already asserting this at the beginning of the function and in semantics.py.

ctx, idx, newRetType.getEncoding(), /*kWidth=*/8);
auto newVType = RankedTensorType::get(
vType.getShape(), vType.getElementType(), newVEncoding);
v = rewriter.create<ConvertLayoutOp>(v.getLoc(), newVType, v);

// Bitcast
auto vTypeFp8 = RankedTensorType::get(vType.getShape(),
enumToType(type), newVEncoding);
v = cast<TypedValue<RankedTensorType>>(
rewriter.create<BitcastOp>(v.getLoc(), vTypeFp8, v).getResult());

// Convert to bf16
auto vTypeBf16 = RankedTensorType::get(
vType.getShape(), rewriter.getBF16Type(), newVEncoding);
return rewriter.create<FpToFpOp>(v.getLoc(), vTypeBf16, v);
}
};
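// Summary of the two paths above (illustration, not part of the diff):
//  - E2M1 (mxfp4): only the encoding changes, to DotOperand(kWidth=4); the
//    tensor keeps its packed-i8 element type and is presumably upcast to bf16
//    later by UpcastMXFPOp.
//  - E4M3 / E5M2 (fp8): the encoding changes to DotOperand(kWidth=8), the
//    payload is bitcast to the corresponding fp8 element type, then converted
//    to bf16.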
a = toMMABf16(a, 0, aType);
b = toMMABf16(b, 1, bType);
7 changes: 7 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp
@@ -44,6 +44,13 @@ class TritonGPUReduceDataDuplicationPass
return;
if (!cvtNeedsSharedMemory(srcType, dstType))
return;
// FIXME [Dot LL]
// We support this one via LLs, as the LocalLoad path is buggy
bool largeKWidth =
dstDotOp.getKWidth() * dstType.getElementTypeBitWidth() > 64;
if (largeKWidth) {
return;
}
auto srcOrder = triton::gpu::getOrder(srcEncoding);
auto rank = srcOrder.size();
SmallVector<unsigned> sharedOrder;