Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Triton] Fix a bug while lowering to LLVM for block_k=16 when the input types involve an 8-bit. This change is porting in this [PR](https://github.com/triton-lang/triton/pull/4768). #19225

Merged
merged 1 commit into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions third_party/triton/temporary/block_k_16_fix.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
--- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -46,6 +46,7 @@ SmallVector<Value> reorderValues(const S
return values;
auto inEncoding = dyn_cast<DotOperandEncodingAttr>(inTensorTy.getEncoding());
auto ouEncoding = dyn_cast<DotOperandEncodingAttr>(ouTensorTy.getEncoding());
+ auto in_shape = inTensorTy.getShape();
assert(inEncoding == ouEncoding);
if (!inEncoding)
return values;
@@ -101,6 +102,11 @@ SmallVector<Value> reorderValues(const S
//
// [0, 1], [4, 5], [2, 3], [6, 7], [8, 9], [12, 13], [10, 11], [14, 15]
SmallVector<Value> ret;
+
+ // In the corner cases (1) where in_shape[0] == 16 and getOpIdx() ==
+ // 1, and (2) where in_shape[1] == 16 and getOpIdx == 0, extra elements will
+ // be loaded. It is necessary to discard these additional elements.
+ bool loadsExtraElements = in_shape[1 - inEncoding.getOpIdx()] == 16;
for (unsigned i = 0; i < values.size(); i += 16) {
ret.push_back(values[i]);
ret.push_back(values[i + 1]);
@@ -110,6 +116,8 @@ SmallVector<Value> reorderValues(const S
ret.push_back(values[i + 3]);
ret.push_back(values[i + 6]);
ret.push_back(values[i + 7]);
+ if (loadsExtraElements)
+ continue;
ret.push_back(values[i + 8]);
ret.push_back(values[i + 9]);
ret.push_back(values[i + 12]);
new file mode 100644
--- /dev/null
+++ b/test/Conversion/tritongpu_to_llvm_ampere.mlir
@@ -0,0 +1,23 @@
+// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm=compute-capability=80 2>&1 | FileCheck %s
+
+#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 3072 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
+ tt.func public @ampere_s8_to_fp16_conversion_opIdx1(%1 : tensor<16x32xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>) attributes {noinline = false} {
+ // CHECK-LABEL: ampere_s8_to_fp16_conversion_opIdx1
+ // CHECK: llvm.sitofp %{{.*}} : i8 to f16
+ %2 = arith.sitofp %1 : tensor<16x32xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> to tensor<16x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+ tt.return
+ }
+}
+
+// -----
+
+#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 3072 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
+ tt.func public @ampere_s8_to_fp16_conversion_opIdx0(%1 : tensor<32x16xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>) attributes {noinline = false} {
+ // CHECK-LABEL: @ampere_s8_to_fp16_conversion_opIdx0
+ // CHECK: llvm.sitofp %{{.*}} : i8 to f16
+ %2 = arith.sitofp %1 : tensor<32x16xi8, #triton_gpu.dot_op<{opIdx = 0 , parent = #mma, kWidth = 4}>> to tensor<32x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+ tt.return
+ }
+}
1 change: 1 addition & 0 deletions third_party/triton/temporary/series.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@ those to this list.

temporary_patch_list = [
"//third_party/triton:temporary/replace_unreachable_by_abort.patch",
"//third_party/triton:temporary/block_k_16_fix.patch",
# Add new patches just above this line
]
3 changes: 3 additions & 0 deletions xla/service/gpu/autotuning/gemm_fusion_autotuner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -838,6 +838,9 @@ GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) {
// Retrieve the minimum bit-width participating in the dot. This is needed
// to avoid autotuning configurations that are not supported by Triton. This
// is used to restrict the values for tile_k.
// TODO(b/378449587): This assumes a convert exists which doesn't cover all
// cases. For example, a bf16 dot(fp8, fp8) will not be handled as the minimum
// bit-width will be 8 but that will not be captured here.
std::vector<const HloInstruction*> converts =
HloBfsFindAll({&dot}, [&](const HloInstruction* node) {
return node->opcode() == HloOpcode::kConvert;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,33 @@ class TritonGemmTestWithoutTritonGemmAny : public TritonGemmTest {
}
};

TEST_F(TritonGemmTest, FP8DotSmallTileDoesNotCrash) {
if (!GetCudaComputeCapability().IsAtLeastHopper()) {
GTEST_SKIP() << "Doesn't pass on pre-Hopper GPUs.";
}

constexpr std::string_view kHloText = R"(
HloModule m

triton_dot {
%parameter_0 = f8e4m3fn[32,32]{1,0} parameter(0)
%parameter_1 = f8e4m3fn[32,32]{1,0} parameter(1)
ROOT %dot.1643 = bf16[32,32]{1,0} dot(f8e4m3fn[32,32]{1,0} %parameter_0, f8e4m3fn[32,32]{0,1} %parameter_1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}

ENTRY e {
p0 = f8e4m3fn[32,32]{1,0} parameter(0)
p1 = f8e4m3fn[32,32]{1,0} parameter(1)
ROOT _ = bf16[32,32] fusion(p0, p1), kind=kCustom, calls=triton_dot,
backend_config={"fusion_backend_config": {kind: "__triton_gemm",
triton_gemm_config: {"block_m":16,"block_n":16,"block_k":16,
"split_k":1,"num_stages":2,"num_warps":2,
"num_ctas":1}}}
})";

EXPECT_TRUE(Run(kHloText, /*run_hlo_passes=*/false));
}

TEST_F(TritonGemmTest, RejectDotInt4HLO) {
constexpr std::string_view kHloText = R"(
HloModule t
Expand Down
Loading