Skip to content

Commit

Permalink
[Triton] Fix a bug while lowering to LLVM for block_k=16 when the inp…
Browse files Browse the repository at this point in the history
…ut types involve an 8-bit. This change is porting in this [PR](triton-lang/triton#4768).

PiperOrigin-RevId: 695275077
  • Loading branch information
Moerafaat authored and Google-ML-Automation committed Nov 11, 2024
1 parent 3c2a6e7 commit 4a4d7af
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 0 deletions.
58 changes: 58 additions & 0 deletions third_party/triton/temporary/block_k_16_fix.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
--- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -46,6 +46,7 @@ SmallVector<Value> reorderValues(const S
return values;
auto inEncoding = dyn_cast<DotOperandEncodingAttr>(inTensorTy.getEncoding());
auto ouEncoding = dyn_cast<DotOperandEncodingAttr>(ouTensorTy.getEncoding());
+ auto in_shape = inTensorTy.getShape();
assert(inEncoding == ouEncoding);
if (!inEncoding)
return values;
@@ -101,6 +102,11 @@ SmallVector<Value> reorderValues(const S
//
// [0, 1], [4, 5], [2, 3], [6, 7], [8, 9], [12, 13], [10, 11], [14, 15]
SmallVector<Value> ret;
+
+ // In the corner cases (1) where in_shape[0] == 16 and getOpIdx() ==
+ // 1, and (2) where in_shape[1] == 16 and getOpIdx == 0, extra elements will
+ // be loaded. It is necessary to discard these additional elements.
+ bool loadsExtraElements = in_shape[1 - inEncoding.getOpIdx()] == 16;
for (unsigned i = 0; i < values.size(); i += 16) {
ret.push_back(values[i]);
ret.push_back(values[i + 1]);
@@ -110,6 +116,8 @@ SmallVector<Value> reorderValues(const S
ret.push_back(values[i + 3]);
ret.push_back(values[i + 6]);
ret.push_back(values[i + 7]);
+ if (loadsExtraElements)
+ continue;
ret.push_back(values[i + 8]);
ret.push_back(values[i + 9]);
ret.push_back(values[i + 12]);
new file mode 100644
--- /dev/null
+++ b/test/Conversion/tritongpu_to_llvm_ampere.mlir
@@ -0,0 +1,23 @@
+// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm=compute-capability=80 2>&1 | FileCheck %s
+
+#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 3072 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
+ tt.func public @ampere_s8_to_fp16_conversion_opIdx1(%1 : tensor<16x32xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>) attributes {noinline = false} {
+ // CHECK-LABEL: ampere_s8_to_fp16_conversion_opIdx1
+ // CHECK: llvm.sitofp %{{.*}} : i8 to f16
+ %2 = arith.sitofp %1 : tensor<16x32xi8, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> to tensor<16x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+ tt.return
+ }
+}
+
+// -----
+
+#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 3072 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
+ tt.func public @ampere_s8_to_fp16_conversion_opIdx0(%1 : tensor<32x16xi8, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>) attributes {noinline = false} {
+ // CHECK-LABEL: @ampere_s8_to_fp16_conversion_opIdx0
+ // CHECK: llvm.sitofp %{{.*}} : i8 to f16
+ %2 = arith.sitofp %1 : tensor<32x16xi8, #triton_gpu.dot_op<{opIdx = 0 , parent = #mma, kWidth = 4}>> to tensor<32x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+ tt.return
+ }
+}
1 change: 1 addition & 0 deletions third_party/triton/temporary/series.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@ those to this list.

temporary_patch_list = [
"//third_party/triton:temporary/replace_unreachable_by_abort.patch",
"//third_party/triton:temporary/block_k_16_fix.patch",
# Add new patches just above this line
]
3 changes: 3 additions & 0 deletions xla/service/gpu/autotuning/gemm_fusion_autotuner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -838,6 +838,9 @@ GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) {
// Retrieve the minimum bit-width participating in the dot. This is needed
// to avoid autotuning configurations that are not supported by Triton. This
// is used to restrict the values for tile_k.
// TODO(b/378449587): This assumes a convert exists which doesn't cover all
// cases. For example, a bf16 dot(fp8, fp8) will not be handled as the minimum
// bit-width will be 8 but that will not be captured here.
std::vector<const HloInstruction*> converts =
HloBfsFindAll({&dot}, [&](const HloInstruction* node) {
return node->opcode() == HloOpcode::kConvert;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,26 @@ class TritonGemmTestWithoutTritonGemmAny : public TritonGemmTest {
}
};

TEST_F(TritonGemmTest, FP8DotDoesNotCrash) {
constexpr std::string_view kHloText = R"(
HloModule m
triton_dot {
%parameter_0 = f8e4m3fn[32,32]{1,0} parameter(0)
%parameter_1 = f8e4m3fn[32,32]{1,0} parameter(1)
ROOT %dot.1643 = bf16[32,32]{1,0} dot(f8e4m3fn[32,32]{1,0} %parameter_0, f8e4m3fn[32,32]{0,1} %parameter_1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
ENTRY e {
p0 = f8e4m3fn[32,32]{1,0} parameter(0)
p1 = f8e4m3fn[32,32]{1,0} parameter(1)
ROOT _ = bf16[32,32] fusion(p0, p1), kind=kCustom, calls=triton_dot,
backend_config="{\"fusion_backend_config\": {kind: \"__triton_gemm\", triton_gemm_config: {\"block_m\":\"16\",\"block_n\":\"16\",\"block_k\":\"16\",\"split_k\":\"1\",\"num_stages\":\"2\",\"num_warps\":\"2\",\"num_ctas\":\"1\"}}}"
})";

EXPECT_TRUE(Run(kHloText, /*run_hlo_passes=*/false));
}

TEST_F(TritonGemmTest, RejectDotInt4HLO) {
constexpr std::string_view kHloText = R"(
HloModule t
Expand Down

0 comments on commit 4a4d7af

Please sign in to comment.