diff --git a/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h b/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h
index 1124daec6dfc..1367f65a031f 100644
--- a/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h
+++ b/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h
@@ -250,13 +250,6 @@ chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
                      ArrayRef<unsigned> repShape,
                      ArrayRef<unsigned> paddedRepShape,
                      ArrayRef<unsigned> order, int swizzleByteSize);
-
-// FIXME
-// Exposing to use it in LinearLayoutConversionsTest.cpp
-// Remove it once we fully activate the DotOperand conversion via LLs
-class DotOperandEncodingAttr;
-LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
-                                     DotOperandEncodingAttr dot);
 } // namespace mlir::triton::gpu
 
 #endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H
diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp
index b44b7560163e..276a6e7004df 100644
--- a/lib/Analysis/Allocation.cpp
+++ b/lib/Analysis/Allocation.cpp
@@ -115,7 +115,12 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
 
   assert(!isMfmaToDotShortcut(srcTy, dstTy));
 
-  auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout);
+  // FIXME This is NOT entirely correct
+  // This should be getElemOrder, but we don't have such a method
+  // TODO Implement getElemOrder and make sure it's consistent with
+  // getContigPerThread
+  auto inOrd = gpu::getThreadOrder(srcLayout);
+  auto outOrd = gpu::getThreadOrder(dstLayout);
   scratchConfig.order = outOrd;
 
   unsigned srcContigPerThread =
diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index 893afc6590f0..71d587d0d92d 100644
--- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -404,6 +404,22 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       }
       return true;
     }
+    if (auto dotOperand = dyn_cast<DotOperandEncodingAttr>(layout)) {
+      if (auto nvidiaMma =
+              dyn_cast<NvidiaMmaEncodingAttr>(dotOperand.getParent())) {
+        if (product(getCTAsPerCGA(nvidiaMma)) > 1) {
+          return false;
+        }
+        if (useLegacyMMAConversion) {
+          return false;
+        }
+        // FIXME [Dot LL]
+        // Enabling LL path for buggy kWidth path
+        bool largeKWidth =
+            dotOperand.getKWidth() * dstTy.getElementTypeBitWidth() > 64;
+        return largeKWidth && nvidiaMma.isAmpere();
+      }
+    }
     if (isa<BlockedEncodingAttr>(layout)) {
       return true;
     }
@@ -460,6 +476,22 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       }
     }
 
+    // FIXME [Dot LL]
+    // We know it's just for largeKWidth case in Ampere
+    // In this case, we need to pack the outputs into i32
+    if (isa<DotOperandEncodingAttr>(dstTy.getEncoding())) {
+      auto concat = [&](Value a, Value b) {
+        return or_(zext(i32_ty, bitcast(a, i16_ty)),
+                   shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16)));
+      };
+
+      SmallVector<Value> outVals32(outVals.size() / 2);
+      for (int i = 0; i < outVals32.size(); ++i) {
+        outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]);
+      }
+      outVals = outVals32;
+    }
+
     Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
                                   op.getType());
     rewriter.replaceOp(op, result);
diff --git a/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp b/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp
index 1346cc143ed2..00d840d7ccf7 100644
--- a/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp
@@ -90,6 +90,16 @@ void decomposeBlockedToDotLayoutConversion(ModuleOp module) {
     auto dstDotOp =
        dyn_cast<triton::gpu::DotOperandEncodingAttr>(dstType.getEncoding());
     if (srcBlocked && dstDotOp) {
+      // FIXME [Dot LL]
+      // We support this one via LLs, as the LocalLoad path is buggy
+      if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dstDotOp.getParent())) {
+        bool largeKWidth =
+            dstDotOp.getKWidth() * dstType.getElementTypeBitWidth() > 64;
+        if (mma.isAmpere() && largeKWidth) {
+          return;
+        }
+      }
+
       Attribute sharedMemorySpace =
           triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext());
       auto tmpType = MemDescType::get(
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index 5d1d3617b08a..8179c1cda1d7 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -8,6 +8,7 @@
 #include "mlir/Support/LLVM.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
+#include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
@@ -234,8 +235,31 @@ static SmallVector<unsigned> eraseOrder(ArrayRef<unsigned> order,
   return resOrder;
 }
 
+SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
+                                            bool kMajor) {
+  // kMajor: if true, the matrix is fastest-running on k,
+  //         otherwise it is on m (resp. n)
+  // opIdx=0: [batch, m, k] if rank == 3 else [m, k]
+  // opIdx=1: [batch, k, n] if rank == 3 else [k, n]
+  // batch (if rank == 3) is always the slowest running dimension
+  assert(rank == 2 || rank == 3);
+  assert(opIdx == 0 || opIdx == 1);
+  SmallVector<unsigned> order(rank);
+  std::iota(order.rbegin(), order.rend(), 0);
+  // If opIdx is 1 and kMajor is true, the order is [0, 1]
+  // (resp. [1, 2, 0] if rank == 3)
+  // Same if opIdx is 0 and kMajor is false
+  if (bool(opIdx) == kMajor) {
+    std::swap(order[0], order[1]);
+  }
+  return order;
+}
+
 SmallVector<unsigned> getWarpOrder(Attribute layout) {
   auto order = getOrder(layout);
+  // FIXME: This mmaLayout if should just return
+  // getOrderForDotOperand(0, order.size(), kMajor=false)
+  // as mma has the same order as DotOperand(opIdx=0)
   if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
     if (mmaLayout.isHopper()) {
       // Hopper MMA instructions force a warp order of [0, 1]. See docs:
@@ -245,40 +269,8 @@ SmallVector<unsigned> getWarpOrder(Attribute layout) {
       order.insert(order.begin(), 0);
     }
   } else if (auto dotOpLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
-    // opIdx=0: [/*dim0*/batch, /*dim1=*/m, /*dim2=*/k] -> order=[1, 2, 0]
-    // opIdx=1: [/*dim0*/batch, /*dim1=*/k, /*dim2=*/n] -> order=[2, 1, 0]
-    std::iota(order.rbegin(), order.rend(), 0);
-    if (dotOpLayout.getOpIdx() == 0) {
-      std::swap(order[0], order[1]);
-    }
-  }
-  return order;
-}
-
-SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank) {
-  assert((rank == 2 || rank == 3) &&
-         "Invalid rank for dot operand order computation");
-  SmallVector<unsigned> order(rank);
-  // The 'order' field typically represents a descending sorted array of
-  // dimensions based on contiguity. For instance, in axisInfo utilities that
-  // retrieve tensor contiguity, it's assumed that the dimension with the
-  // highest contiguity corresponds to order[0].
-  //
-  // The relation between contiguity and order is only relevant if the layout
-  // interfaces with HBM, as is the case when we load tensor from HBM to
-  // registers in the dot layout to bypass LDS. When bypassing LDS, we make
-  // the following assumptions about tensor layouts:
-  // - Tensor A (opIdx == 0) is considered to be row-major.
-  // - Tensor B (opIdx == 1) is considered to be column-major.
-  //
-  // Based on these assumptions, we define the following orders:
-  // - For opIdx == 0, batch=dim0, m=dim1, and k=dim2, we assume an order of [2,
-  //   1, 0] for 3D tensors.
-  // - For opIdx == 1, batch=dim0, k=dim1, and n=dim2, we assume an order of [1,
-  //   2, 0] for 3D tensors.
-  std::iota(order.rbegin(), order.rend(), 0);
-  if (opIdx == 1) {
-    std::swap(order[0], order[1]);
+    order = getOrderForDotOperand(dotOpLayout.getOpIdx(), order.size(),
+                                  /*kMajor*/ false);
   }
   return order;
 }
@@ -295,8 +287,8 @@ SmallVector<unsigned> getOrder(Attribute layout) {
     return order;
   }
   if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
-    auto rank = getWarpsPerCTA(dotLayout.getParent()).size();
-    return getOrderForDotOperand(dotLayout.getOpIdx(), rank);
+    auto rank = dotLayout.getWarpsPerCTA().size();
+    return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kMajor*/ true);
   }
   if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {
     SmallVector<unsigned> parentOrder = getOrder(sliceLayout.getParent());
@@ -1048,7 +1040,8 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
   return ::getWarpOrder(*this);
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
-  return ::getOrder(*this);
+  return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
+                               /*kMajor*/ true);
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getShapePerCTATile(
     ArrayRef<int64_t> tensorShape) const {
@@ -2019,6 +2012,7 @@ SmallVector<int64_t> NvidiaMmaEncodingAttr::getMMAv2RepForOperand(
     ArrayRef<int64_t> shape, int bitwidth, int kWidth, int opIdx) const {
   auto rank = shape.size();
   auto warpsPerCTA = getWarpsPerCTA();
+
   SmallVector<int> shapePerWarp = {1, 16, 8, 4 * 64 / bitwidth};
   int numRepBatch =
       rank == 3
diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
index bc365057f811..039de22cb4a5 100644
--- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
+++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -886,13 +886,14 @@ std::optional<LinearLayout>
 DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   if (auto mfmaLayout = llvm::dyn_cast<AMDMfmaEncodingAttr>(getParent())) {
     return dotOperandMfmaToLinearLayout(*this, shape);
+  } else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(getParent())) {
+    // FIXME [Dot LL]
+    // Do this unconditionally
+    auto largeKWidth = getKWidth() == 8;
+    if (mma.isAmpere() && largeKWidth) {
+      return ampereDotToLinearLayout(shape, *this);
+    }
   }
-  // TODO Activate in a follow-up PR
-  // else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(getParent())) {
-  //   if (mma.isAmpere()) {
-  //     return ampereDotToLinearLayout(shape, *this);
-  //   }
-  //}
 
   return std::nullopt;
 }
diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp
index b9f3d3040dd3..e61fe096e10b 100644
--- a/lib/Dialect/TritonGPU/IR/Ops.cpp
+++ b/lib/Dialect/TritonGPU/IR/Ops.cpp
@@ -17,8 +17,9 @@ LogicalResult UpcastMXFPOp::verify() {
   auto xTy = getSrc().getType();
   auto scaleTy = getScale().getType();
 
-  if (xTy.getElementType() != FloatType::getBF16(getContext())) {
-    return emitOpError("element type of the first operand must be bf16");
+  if (xTy.getElementType() != FloatType::getBF16(getContext()) &&
+      xTy.getElementType() != IntegerType::get(getContext(), 8)) {
+    return emitOpError("element type of the first operand must be bf16 or i8");
   }
 
   if (scaleTy.getElementType() != IntegerType::get(getContext(), 8)) {
@@ -72,7 +73,7 @@ LogicalResult UpcastMXFPOp::verify() {
 }
 
 LogicalResult UpcastMXFPOp::inferReturnTypes(
-    MLIRContext *context, std::optional<Location> location, ValueRange operands,
+    MLIRContext *ctx, std::optional<Location> loc, ValueRange operands,
     DictionaryAttr attributes, OpaqueProperties opaqueProperties,
     RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
   auto xTy = cast<RankedTensorType>(operands[0].getType());
@@ -82,21 +83,25 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
 
   auto encoding = xTy.getEncoding();
   if (!encoding) {
-    return emitOptionalError(location, "expected an encoding");
+    return emitOptionalError(loc, "expected an encoding");
   }
   if (!mlir::isa<DotOperandEncodingAttr>(encoding)) {
-    return emitOptionalError(location, "expected an mma layout encoding");
-  }
-  if (xShape.size() < 2) {
-    return emitOptionalError(location, "tensor rank must be at least 2");
+    return emitOptionalError(loc, "expected a dotOperand encoding");
   }
 
-  // For now we just return the input encoding. For fp4 we'll need to cast from
-  // tf32 to fp16 encoding and multiply the shape by two
-  assert((typeEncoded == F8F6F4Type::E4M3 || typeEncoded == F8F6F4Type::E5M2) &&
-         "NYI: only fp8e4m3 and fp8e5m2 are supported");
+  if (typeEncoded == F8F6F4Type::E2M1) {
+    auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
+    auto newVEncoding = DotOperandEncodingAttr::get(
+        ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(),
+        oldEncoding.getKWidth() * 2);
+    auto newShape = SmallVector<int64_t>(xShape);
+    newShape.back() *= 2;
+    inferredReturnTypes.push_back(
+        RankedTensorType::get(newShape, FloatType::getBF16(ctx), newVEncoding));
+  } else {
+    inferredReturnTypes.push_back(xTy);
+  }
 
-  inferredReturnTypes.push_back(xTy);
   return success();
 }
diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
index 08a88ae397a7..a2d4012bf23e 100644
--- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -406,7 +406,7 @@ class ScaledBlockedToMMAv2
     auto ctx = dotOp.getContext();
 
     // Check that rhs scale is null
-    assert(dotOp.getRhsScale() == nullptr && "rhs scale must be null");
+    assert(dotOp.getRhsScale() == nullptr && "rhs scale NYI");
 
     // operands
    auto a = dotOp.getLhs();
@@ -426,10 +426,11 @@ class ScaledBlockedToMMAv2
       }
     };
 
-    assert(aType == F8F6F4Type::E4M3 ||
-           aType == F8F6F4Type::E5M2 && "lhs just supports fp8");
+    assert((aType == F8F6F4Type::E4M3 || aType == F8F6F4Type::E5M2 ||
+            aType == F8F6F4Type::E2M1) &&
+           "NYI: lhs supports fp4 or fp8");
     assert(bType == F8F6F4Type::E4M3 ||
-           bType == F8F6F4Type::E5M2 && "rhs just supports fp8");
+           bType == F8F6F4Type::E5M2 && "NYI: rhs supports fp8");
 
     // TODO run accelerate matmul on A and B first to choose their layouts
     // Set return type
@@ -440,6 +441,7 @@ class ScaledBlockedToMMAv2
     auto instrShape = mmaVersionToInstrShape(versionMajor, retShapePerCTA,
                                              rewriter.getBF16Type(), numWarps);
     auto CTALayout = getCTALayout(oldRetType.getEncoding());
+    // TODO Use warpsPerTileV2
     SmallVector<unsigned> warpsPerCTA = {numWarps, 1};
     auto mmaEnc = NvidiaMmaEncodingAttr::get(ctx, /*versionMajor=*/versionMajor,
                                              /*versionMinor=*/0, warpsPerCTA,
@@ -452,27 +454,39 @@ class ScaledBlockedToMMAv2
     auto newAcc =
         rewriter.create<ConvertLayoutOp>(oldAcc.getLoc(), newRetType, oldAcc);
 
-    auto toMMABf16 = [&newRetType, &rewriter, &ctx,
-                      &enumToType](TypedValue<RankedTensorType> v, int idx,
-                                   F8F6F4Type type) {
-      // MMAv2 Layout
+    auto toMMABf16 = [&newRetType, &rewriter, &ctx, &enumToType](
+                         TypedValue<RankedTensorType> v, int idx,
+                         F8F6F4Type type) -> TypedValue<RankedTensorType> {
       auto vType = v.getType();
-      auto newVEncoding = DotOperandEncodingAttr::get(
-          ctx, idx, newRetType.getEncoding(), enumToType((type)));
-      auto newVType = RankedTensorType::get(
-          v.getType().getShape(), v.getType().getElementType(), newVEncoding);
-      v = rewriter.create<ConvertLayoutOp>(v.getLoc(), newVType, v);
-
-      // Bitcast
-      auto vTypeFp8 = RankedTensorType::get(
-          vType.getShape(), rewriter.getFloat8E4M3FNType(), newVEncoding);
-      v = cast<TypedValue<RankedTensorType>>(
-          rewriter.create<BitcastOp>(v.getLoc(), vTypeFp8, v).getResult());
-
-      // Convert to bf16
-      auto vTypeBf16 = RankedTensorType::get(
-          vType.getShape(), rewriter.getBF16Type(), newVEncoding);
-      return rewriter.create<FpToFpOp>(v.getLoc(), vTypeBf16, v);
+      if (type == F8F6F4Type::E2M1) {
+        // A bit too dynamically typed...
+        // perhaps return ints in both cases?
+
+        auto retEnc = dyn_cast<NvidiaMmaEncodingAttr>(newRetType.getEncoding());
+        auto newVEncoding = DotOperandEncodingAttr::get(
+            ctx, idx, newRetType.getEncoding(), /*kWidth=*/4);
+        auto newVType = RankedTensorType::get(
+            vType.getShape(), vType.getElementType(), newVEncoding);
+        return rewriter.create<ConvertLayoutOp>(v.getLoc(), newVType, v);
+      } else {
+        assert(type == F8F6F4Type::E5M2 || type == F8F6F4Type::E4M3);
+        auto newVEncoding = DotOperandEncodingAttr::get(
+            ctx, idx, newRetType.getEncoding(), /*kWidth=*/8);
+        auto newVType = RankedTensorType::get(
+            vType.getShape(), vType.getElementType(), newVEncoding);
+        v = rewriter.create<ConvertLayoutOp>(v.getLoc(), newVType, v);
+
+        // Bitcast
+        auto vTypeFp8 = RankedTensorType::get(vType.getShape(),
+                                              enumToType(type), newVEncoding);
+        v = cast<TypedValue<RankedTensorType>>(
+            rewriter.create<BitcastOp>(v.getLoc(), vTypeFp8, v).getResult());
+
+        // Convert to bf16
+        auto vTypeBf16 = RankedTensorType::get(
+            vType.getShape(), rewriter.getBF16Type(), newVEncoding);
+        return rewriter.create<FpToFpOp>(v.getLoc(), vTypeBf16, v);
+      }
     };
     a = toMMABf16(a, 0, aType);
     b = toMMABf16(b, 1, bType);
diff --git a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp
index b1e296c1bbe4..3a406c3cc28e 100644
--- a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp
@@ -44,6 +44,13 @@ class TritonGPUReduceDataDuplicationPass
         return;
       if (!cvtNeedsSharedMemory(srcType, dstType))
         return;
+      // FIXME [Dot LL]
+      // We support this one via LLs, as the LocalLoad path is buggy
+      bool largeKWidth =
+          dstDotOp.getKWidth() * dstType.getElementTypeBitWidth() > 64;
+      if (largeKWidth) {
+        return;
+      }
       auto srcOrder = triton::gpu::getOrder(srcEncoding);
       auto rank = srcOrder.size();
       SmallVector<unsigned> sharedOrder;
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index a0b149099dc1..4780acc53bbb 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -3319,10 +3319,10 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
     (M, N, K, col_a, col_b, type_a, type_b, 4)
     for M, N, K in itertools.product([32, 64, 128], [32, 64, 128], [64, 128])
     for col_a, col_b in itertools.product([True, False], repeat=2)
-    # We don't test e5m2 as it seems to overflow more easily
     # Tested locally and it works fine other than for ~10 entries out of 10_000
     # which are of the size of 10**30
-    for type_a in ["e4m3"]
+    for type_a in ["e2m1", "e4m3"]
     for type_b in ["e4m3"]
 ])
 def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, device):
@@ -3427,10 +3427,9 @@ def mxfp_to_bf16_kernel(
     tl.store(mxfp_ptr + offsets, tl.ravel(mxfp), mask=offsets < N * 32)
 
 
 def dot_scale_ref(x, scale, y, type_x, type_y):
-    e_bits, m_bits = {"e4m3": (4, 3), "e5m2": (5, 2)}[type_x]
+    e_bits, m_bits = {"e2m1": (2, 1), "e4m3": (4, 3), "e5m2": (5, 2)}[type_x]
     type_fp8_y = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2}[type_y]
-    # Need to implement fp4 -> fp8 cast to support 1 byte types
 
     comp_dtype = torch.bfloat16
     out_dtype = torch.float32
@@ -3447,11 +3446,17 @@ def dot_scale_ref(x, scale, y, type_x, type_y):
 
     torch.manual_seed(0)
 
-    def create_uint8(shape):
-        return torch.randint(0xff, shape, dtype=torch.uint8, device=device)
-
-    x = create_uint8((K, M)).T if col_a else create_uint8((M, K))
-    y = create_uint8((N, K)).T if col_b else create_uint8((K, N))
+    def create_uint8(shape, col_major=False):
+        if col_major:
+            shape = shape[:-2] + (shape[-1], shape[-2])
+        ret = torch.randint(1 << 8, shape, dtype=torch.uint8, device=device)
+        if col_major:
+            ret = ret.mT
+        return ret
+
+    DIV_FACTOR = 2 if type_a == "e2m1" else 1
+    x = create_uint8((M, K // DIV_FACTOR), col_major=col_a)
+    y = create_uint8((K, N), col_major=col_b)
     scale_x = create_uint8((M, K // 32))
 
     z = x.new_empty((M, N), dtype=torch.float32)
diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index 1fdfbadcd290..be157c5b4609 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -1548,9 +1548,18 @@ def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format, rhs: tl.tensor,
     lhs_rank = len(lhs.shape)
     rhs_rank = len(rhs.shape)
     assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})"
-    M, K = lhs.type.shape[-2:]
-    N = rhs.type.shape[-1]
-    assert K == rhs.type.shape[-2], f"Reduction dimension should agree; (lhs: {lhs.shape} vs rhs: {rhs.shape})"
+    lhs_format_enum = _str_to_fp_type(lhs_format)
+    rhs_format_enum = _str_to_fp_type(rhs_format)
+    assert lhs_format in ("e2m1", "e4m3", "e5m2"), f"NYI: lhs_format {lhs_format}"
+    assert rhs_format in ("e4m3", "e5m2"), f"NYI: rhs_format {rhs_format}"
+    rhs_scale_is_none = isinstance(rhs_scale, tl.constexpr) and rhs_scale.value is None
+    assert rhs_scale_is_none, "NYI: rhs_scale not supported"
+
+    M = lhs.type.shape[-2]
+    K, N = rhs.type.shape[-2:]
+    PACKED = 2 if lhs_format == "e2m1" else 1
+    assert K == PACKED * lhs.type.shape[
+        -1], f"Reduction dimension should pack the same number of elements; (lhs: {lhs.shape} vs rhs: {rhs.shape})"
     assert K >= 64, f"scaled_dot NYI for K < 64. Got {K=}"
     B = lhs.type.shape[0] if lhs_rank == 3 else None
 
@@ -1561,9 +1570,7 @@ def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format, rhs: tl.tensor,
     else:
         acc_handle = acc.handle
         assert acc.type == ret_ty
-    lhs_format_enum = _str_to_fp_type(lhs_format)
-    rhs_format_enum = _str_to_fp_type(rhs_format)
-    rhs_scale_handle = None if isinstance(rhs_scale, tl.constexpr) else rhs_scale.handle
+    rhs_scale_handle = None if rhs_scale_is_none else rhs_scale.handle
     return tl.tensor(
         builder.create_dot_scaled(lhs.handle, lhs_scale.handle, lhs_format_enum, rhs.handle, rhs_scale_handle,
                                   rhs_format_enum, acc_handle), ret_ty)
diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir
index a63dd4ff885e..a0719c974f9c 100644
--- a/test/Analysis/test-allocation.mlir
+++ b/test/Analysis/test-allocation.mlir
@@ -39,7 +39,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>,
     // CHECK: offset = 0, size = 4608
     %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT>
     %b_ = tt.load %b_ptr, %b_mask, %b_other : tensor<32x128x!tt.ptr<f16>, #BL>
-    // CHECK-NEXT: offset = 0, size = 4224
+    // CHECK-NEXT: offset = 0, size = 4352
    %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B_DOT>
 
    %c = tt.dot %a, %b, %prev_c : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
@@ -67,14 +67,14 @@ tt.func @reusable(%A : !tt.ptr<f16>) {
   // CHECK-NEXT: offset = 0, size = 4608
   %a1 = triton_gpu.convert_layout %a1_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT>
   %a2_ = tt.load %b_ptr, %cst3, %cst4 : tensor<32x128x!tt.ptr<f16>, #AL>
-  // CHECK-NEXT: offset = 0, size = 1152
+  // CHECK-NEXT: offset = 0, size = 1088
   %a2 = triton_gpu.convert_layout %a2_ : tensor<32x128xf16, #AL> -> tensor<32x128xf16, #B_DOT>
   %a3_ = tt.load %a_ptr, %cst1, %cst2 : tensor<128x32x!tt.ptr<f16>, #AL>
   // CHECK-NEXT: offset = 0, size = 4608
   %a3 = triton_gpu.convert_layout %a3_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT>
   %c = tt.dot %a1, %a2, %c_init : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
   %a4_ = tt.load %b_ptr, %cst3, %cst4 : tensor<32x128x!tt.ptr<f16>, #AL>
-  // CHECK-NEXT: offset = 0, size = 1152
+  // CHECK-NEXT: offset = 0, size = 1088
   %a4 = triton_gpu.convert_layout %a4_ : tensor<32x128xf16, #AL> -> tensor<32x128xf16, #B_DOT>
   %c1 = tt.dot %a3, %a4, %c : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
   tt.return
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
index 4a3f530a747d..79ccb57206ae 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
@@ -1,6 +1,9 @@
 #include "TritonNVIDIAGPUToLLVM/PTXAsmFormat.h"
 #include "Utility.h"
 #include "mlir/Support/LLVM.h"
+#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
+#include "triton/Dialect/TritonGPU/IR/Attributes.h"
+#include "llvm/ADT/SmallVector.h"
 
 using namespace mlir;
 using namespace mlir::triton;
@@ -58,10 +61,53 @@ ValueTableV2 getValuesFromDotOperandLayoutStruct(
     const LLVMTypeConverter *typeConverter, Location loc,
     ConversionPatternRewriter &rewriter, Value value, int batch, int n0,
     int n1, RankedTensorType type) {
-
   auto elems = unpackLLElements(loc, value, rewriter);
   int offset{};
   ValueTableV2 vals;
+
+  // FIXME [Dot LL]
+  // [ez] Generalize the logic below for kWidth * elemBitWidth > 32
+  auto dot = cast<DotOperandEncodingAttr>(type.getEncoding());
+  auto largeK = dot.getKWidth() == 8 &&
+                cast<NvidiaMmaEncodingAttr>(dot.getParent()).isAmpere();
+  if (largeK) {
+    llvm::SmallVector<unsigned> si;
+
+    // For kWidth = 8, split the mma into 4 mmas with "stride 4" along K
+    if (dot.getOpIdx() == 0) {
+      si = llvm::SmallVector<unsigned>{0, 8, 4, 12, 1, 9, 5, 13,
+                                       2, 10, 6, 14, 3, 11, 7, 15};
+    } else {
+      si = llvm::SmallVector<unsigned>{0, 4, 1, 5, 2, 6, 3, 7};
+    }
+
+    auto step = si.size();
+    SmallVector<Value> perm(step);
+    for (auto i = 0; i < elems.size() / step; ++i) {
+      for (auto j = 0; j < step; ++j) {
+        perm[j] = elems[i * step + si[j]];
+      }
+      std::copy(perm.begin(), perm.end(), elems.begin() + i * step);
+    }
+
+    if (dot.getOpIdx() == 1) {
+      // there are kWidth * 2 elems packed as bf16x2
+      int elemsInTile = dot.getKWidth();
+      // n0 and n1 are unrolled in the legacy path
+      // Unrolling n1 makes some sense, but unrolling n0 makes absolutely no
+      // sense IMO
+      n0 *= 2;
+      n1 *= 2;
+      for (auto b = 0; b < batch; ++b)
+        for (auto j = 0; j < n1 / elemsInTile; ++j)
+          for (auto i = 0; i < n0; ++i)
+            for (auto k = 0; k < elemsInTile; ++k) {
+              vals[{b, i, elemsInTile * j + k}] = elems[offset++];
+            }
+      return vals;
+    }
+  }
+
   for (auto b = 0; b < batch; ++b)
     for (auto i = 0; i < n0; ++i) {
       for (auto j = 0; j < n1; j++) {
@@ -330,9 +376,13 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter,
   int repM = repA[1], repN = repB[2], repK = repA[2];
   int repBatch = repA[0];
 
-  // shape / shape_per_cta
   auto ha = getValuesFromDotOperandLayoutStruct(
       typeConverter, loc, rewriter, loadedA, repBatch, repM, repK, aTensorTy);
+
+  // FIXME [Dot LL]
+  // max(repN / 2, 1) is wrong for repN = 1!
+  // This is also wrong in
+  // NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperand
   auto hb = getValuesFromDotOperandLayoutStruct(
       typeConverter, loc, rewriter, loadedB, repBatch, std::max(repN / 2, 1),
       repK, bTensorTy);
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp
index aeca44bb46ce..9404bb4474d0 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp
@@ -1,4 +1,5 @@
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
 
@@ -8,8 +9,10 @@
 #include "mlir/IR/ValueRange.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/raw_ostream.h"
 #include <array>
 
 using namespace mlir;
@@ -27,6 +30,73 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
       : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {}
 
+  llvm::SmallVector<Value>
+  unpackFP4Elements(Location loc, ConversionPatternRewriter &rewriter,
+                    const llvm::SmallVector<Value> &vals, Value laneId) const {
+    auto fp4x2ToBf16x2 = [&loc, &rewriter](Value v) -> Value {
+      auto em0 = and_(v, i8_val(0x70));
+      auto em1 = and_(v, i8_val(0x7));
+      Value v0 = or_(shl(zext(i16_ty, em0), i16_val(2)),
+                     shl(zext(i16_ty, and_(v, i8_val(0x80))), i16_val(8)));
+      Value v1 = or_(shl(zext(i16_ty, em1), i16_val(6)),
+                     shl(zext(i16_ty, and_(v, i8_val(0x8))), i16_val(12)));
+
+      // Three cases:
+      // 1) x is normal and non-zero: Correct bias
+      v0 = select(icmp_ne(and_(em0, i8_val(0x60)), i8_val(0)),
+                  add(v0, i16_val((127 - 1) << 7)), v0);
+      v1 = select(icmp_ne(and_(em1, i8_val(0x6)), i8_val(0)),
+                  add(v1, i16_val((127 - 1) << 7)), v1);
+
+      // 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in
+      //    bf16
+      v0 = select(icmp_eq(em0, i8_val(0x10)),
+                  or_(i16_val(16128), and_(v0, i16_val(0x8000))), v0);
+      v1 = select(icmp_eq(em1, i8_val(0x1)),
+                  or_(i16_val(16128), and_(v1, i16_val(0x8000))), v1);
+      // 3) x is zero, nothing to do
+
+      // Swap as they come packed in big endian
+      return or_(zext(i32_ty, v0), shl(zext(i32_ty, v1), i32_val(16)));
+    };
+
+    auto fp4x8ToBf16x2 = [&loc, &rewriter, &fp4x2ToBf16x2](
+                             Value v) -> llvm::SmallVector<Value> {
+      llvm::SmallVector<Value> results(4);
+      for (int i = 0; i < 4; ++i) {
+        auto v_i = trunc(i8_ty, lshr(v, i32_val(8 * i)));
+        results[i] = fp4x2ToBf16x2(v_i);
+      }
+      return results;
+    };
+
+    // Split fp4x8 into 4 bf16x2
+    llvm::SmallVector<Value> ret;
+    ret.reserve(vals.size() * 4);
+    for (int i = 0; i < vals.size(); ++i) {
+      auto vs = fp4x8ToBf16x2(vals[i]);
+      assert(vs.size() == 4);
+      for (auto v : vs) {
+        ret.push_back(v);
+      }
+    }
+    // FIXME [Dot LL]
+    // The DotOperandEncodingAttr without LLs encodes the
+    // layout as
+    // e0 e1
+    // e2 e3
+    // rather than transposed that, as the PTX docs say
+    // We transpose every block of 4 elements (kWidth = 8 -> 4 bf16x2)
+    assert(ret.size() % 16 == 0);
+    for (int i = 0; i < ret.size() / 16; ++i) {
+      for (int j = 0; j < 4; ++j) {
+        std::swap(ret[16 * i + j + 4], ret[16 * i + j + 8]);
+      }
+    }
+
+    return ret;
+  }
+
   LogicalResult
   matchAndRewrite(UpcastMXFPOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -37,6 +107,7 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
 
     auto xVals = unpackLLElements(loc, operands[0], rewriter);
     auto scaleVals = unpackLLElements(loc, operands[1], rewriter);
+    auto fpType = op.getFpType();
 
     Value tid = tid_val();
     auto mod = op->getParentOfType<ModuleOp>();
@@ -45,7 +116,11 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
     Value warpId = udiv(tid, warpSize);
     Value laneId = urem(tid, warpSize);
 
-    auto scale = [&loc, &rewriter](Value v, Value s) -> Value {
+    if (fpType == F8F6F4Type::E2M1) {
+      xVals = unpackFP4Elements(loc, rewriter, xVals, laneId);
+    }
+
+    auto scaleBf16x2 = [&loc, &rewriter](Value v, Value s) -> Value {
       // Split bf16x2 into 2 bf16, scale each of them, and pack them back
       // TODO Is it true that the bfloats are always packed as bf16x2?
       auto bf16_0 = bitcast(trunc(i16_ty, v), bf16_ty);
@@ -69,18 +144,18 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
     auto c = mul(udiv(laneId, i32_val(4)), i32_val(2));
     std::array<Value, 4> ci = {c, add(c, i32_val(1)), add(c, i32_val(16)),
                                add(c, i32_val(17))};
+
     for (auto [i, scaleVal] : llvm::enumerate(scaleVals)) {
+      // column major as per the DotOperandEncoding(opidx=0) layout
       auto si = std::array<Value, 4>{
           targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[0]),
-          targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[1]),
           targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[2]),
+          targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[1]),
          targetInfo.shuffleIdx(rewriter, loc, scaleVal, ci[3]),
      };
-      // si indices for the 16 elements in x
-      std::array<int, 16> siMap = {0, 0, 2, 2, 0, 0, 2, 2,
-                                   1, 1, 3, 3, 1, 1, 3, 3};
+
       for (int j = 0; j < 16; ++j) {
-        xVals[16 * i + j] = scale(xVals[16 * i + j], si[siMap[j]]);
+        xVals[16 * i + j] = scaleBf16x2(xVals[16 * i + j], si[j / 4]);
       }
     }
diff --git a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
index 015a450dfff0..76c9c442257d 100644
--- a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
+++ b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
@@ -502,7 +502,7 @@ TEST_F(LinearLayoutConversionsTest, MMAv3_4x4Warps) {
 }
 
 TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) {
-  EXPECT_EQ(ampereDotToLinearLayout({16, 64}, dotMMAv2(0, 8, {1, 1}, {1, 0})),
+  EXPECT_EQ(toLinearLayout({16, 64}, dotMMAv2(0, 8, {1, 1}, {1, 0})),
             LinearLayout(
                 {
                    {S("register"), {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}}},
@@ -511,7 +511,7 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) {
                    {S("block"), {}},
                },
                {S("dim0"), S("dim1")}));
-  EXPECT_EQ(ampereDotToLinearLayout({64, 8}, dotMMAv2(1, 8, {1, 1}, {1, 0})),
+  EXPECT_EQ(toLinearLayout({64, 8}, dotMMAv2(1, 8, {1, 1}, {1, 0})),
            LinearLayout(
                {
                    {S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}}},
@@ -524,7 +524,7 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_tile_kwidth8) {
 
 TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) {
   EXPECT_EQ(
-      ampereDotToLinearLayout({128, 128}, dotMMAv2(0, 8, {4, 1}, {1, 0})),
+      toLinearLayout({128, 128}, dotMMAv2(0, 8, {4, 1}, {1, 0})),
      LinearLayout(
          {
              {S("register"),
@@ -534,7 +534,7 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) {
                  {S("block"), {}},
              },
              {S("dim0"), S("dim1")}));
-  EXPECT_EQ(ampereDotToLinearLayout({128, 64}, dotMMAv2(1, 8, {4, 1}, {1, 0})),
+  EXPECT_EQ(toLinearLayout({128, 64}, dotMMAv2(1, 8, {4, 1}, {1, 0})),
            LinearLayout(
                {
                    {S("register"),
@@ -554,7 +554,7 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) {
                  {S("block"), {}},
              },
              {S("dim0"), S("dim1")}));
-  EXPECT_EQ(ampereDotToLinearLayout({64, 128}, dotMMAv2(1, 8, {4, 1}, {1, 0})),
+  EXPECT_EQ(toLinearLayout({64, 128}, dotMMAv2(1, 8, {4, 1}, {1, 0})),
            LinearLayout(
                {
                    {S("register"),
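
Note (not part of the patch): a minimal PyTorch sketch of the E2M1 -> bf16 value mapping that the fp4x2ToBf16x2 lambda in UpcastMXFPToLLVM.cpp implements (bias 1, subnormal 0bS001 -> +-0.5, zero stays zero). The helper name e2m1_to_bf16 and the low-nibble-first unpacking order are assumptions made for illustration; the in-tree reference for full mxfp handling is mxfp_to_bf16_kernel in test_core.py.

import torch

# All representable E2M1 magnitudes, indexed by the 3 exponent/mantissa bits.
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def e2m1_to_bf16(x: torch.Tensor) -> torch.Tensor:
    # Decode a uint8 tensor of packed fp4 (E2M1) pairs into bf16, doubling the last dim.
    assert x.dtype == torch.uint8
    # 4-bit code = sign | exp(2 bits) | mantissa(1 bit); codes 8..15 are the negatives of 0..7.
    lut = torch.tensor(E2M1_MAGNITUDES + [-m for m in E2M1_MAGNITUDES],
                       dtype=torch.bfloat16, device=x.device)
    lo, hi = x & 0x0F, x >> 4
    # Assumed packing: element 2k in the low nibble, 2k+1 in the high nibble.
    out = torch.stack([lut[lo.long()], lut[hi.long()]], dim=-1)
    return out.reshape(*x.shape[:-1], x.shape[-1] * 2)

# e.g. e2m1_to_bf16(torch.tensor([[0x21]], dtype=torch.uint8)) -> [[0.5, 1.0]]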