SWDEV-492517 #65

Merged
27 changes: 27 additions & 0 deletions third_party/triton/temporary/amd_pr7.patch
@@ -0,0 +1,27 @@
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
index b0976f8..bcdc5c7 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -956,6 +956,22 @@ struct FpToFpOpConversion
     for (unsigned i = 0; i < std::min(numElements, operands.size()); i++) {
       inVals.push_back(operands[i][0]);
     }
+
+    bool isSrcFP16 = srcElementType.isF16();
+    bool isSrcBF16 = srcElementType.isBF16();
+
+    if ((isSrcFP16 || isSrcBF16)
+        && isDstFP32) {
+      SmallVector<Value> outVals;
+      for (Value &v : inVals) {
+        if (isSrcFP16)
+          outVals.push_back(convertFp16ToFp32(loc, rewriter, v));
+        else
+          outVals.push_back(convertBf16ToFp32(loc, rewriter, v));
+      }
+      return outVals;
+    }
+
     if (useFP16IntermediateSrc)
       for (Value &v : inVals)
         v = cvtFp32ToFp16(loc, rewriter, v,
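The new early return in `FpToFpOpConversion` short-circuits pure FP16/BF16 → FP32 upcasts before the `useFP16IntermediateSrc` handling below can route them through a lossy intermediate. As a rough sketch of what the two helpers named in the patch are expected to lower to — the names come from the patch itself, the body below is illustrative, assuming the usual MLIR pattern of emitting a single widening `fpext`:

```cpp
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Illustrative sketch only: an upcast to f32 is one LLVM fpext, so the early
// return above introduces no intermediate rounding step.
static Value convertFp16ToFp32(Location loc,
                               ConversionPatternRewriter &rewriter, Value v) {
  return rewriter.create<LLVM::FPExtOp>(loc, rewriter.getF32Type(), v);
}
```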
4 changes: 3 additions & 1 deletion third_party/triton/temporary/series.bzl
@@ -5,4 +5,6 @@ These are created temporarily and should be moved to the first copybara workflow
 internal patch during the next triton integration process.
 """
 
-temporary_patch_list = []
+temporary_patch_list = [
+    "//third_party/triton/temporary:amd_pr7.patch",
+]
1 change: 1 addition & 0 deletions third_party/tsl/third_party/gpus/rocm_configure.bzl
@@ -711,6 +711,7 @@ def _create_local_rocm_repository(repository_ctx):
"-DTENSORFLOW_USE_ROCM=1",
"-D__HIP_PLATFORM_AMD__",
"-DEIGEN_USE_HIP",
"-DUSE_ROCM",
])

rocm_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc"
2 changes: 1 addition & 1 deletion xla/debug_options_flags.cc
@@ -180,7 +180,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
   opts.set_xla_partitioning_algorithm(
       DebugOptions::PARTITIONING_ALGORITHM_NOOP);
 
-  opts.set_xla_gpu_enable_triton_gemm(true);
+  opts.set_xla_gpu_enable_triton_gemm(false);
   opts.set_xla_gpu_enable_cudnn_int8x32_convolution_reordering(true);
   opts.set_xla_gpu_triton_gemm_any(false);
   opts.set_xla_gpu_enable_triton_softmax_fusion(false);
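Flipping the default to `false` means Triton GEMM fusions must now be opted into explicitly — e.g. with `--xla_gpu_enable_triton_gemm=true` in `XLA_FLAGS` — which is exactly what the `TritonGemmTest` fixture change further down in this PR does via `set_xla_gpu_enable_triton_gemm(true)`.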
4 changes: 3 additions & 1 deletion xla/service/gpu/BUILD
@@ -585,12 +585,14 @@ cc_library(
"@triton//:TritonGPUToLLVM",
"@triton//:TritonToTritonGPU",
"@triton//:TritonGPUTransforms",
"@triton//:TritonLLVMIR",
]) + if_cuda_is_configured([
"@triton//third_party/nvidia:NVGPUToLLVM",
"@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM",
"@triton//:TritonLLVMIR",
]) + if_rocm_is_configured([
"@tsl//tsl/platform:rocm_rocdl_path",
"@triton//third_party/amd:TritonAMDGPUToLLVM",
"@triton//third_party/amd:TritonAMDGPUTransforms",
]),
)

Expand Down
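`TritonLLVMIR` moves from the CUDA-only branch to the unconditional Triton dependency list so the shared LLVM-IR lowering links on both platforms, while the AMD-specific backends (`TritonAMDGPUToLLVM`, `TritonAMDGPUTransforms`) are pulled in only when `if_rocm_is_configured` is active.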
7 changes: 4 additions & 3 deletions xla/service/gpu/amdgpu_compiler.cc
@@ -39,6 +39,7 @@ limitations under the License.
#include "xla/service/gpu/cudnn_fused_conv_rewriter.h"
#include "xla/service/gpu/cusolver_rewriter.h"
#include "xla/service/gpu/gemm_algorithm_picker.h"
#include "xla/service/gpu/gpu_algebraic_simplifier.h"
#include "xla/service/gpu/gpu_compiler.h"
#include "xla/service/gpu/gpu_conv_padding_legalization.h"
#include "xla/service/gpu/gpu_conv_rewriter.h"
@@ -141,7 +142,7 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization(
GetAlgebraicSimplifierOptions(hlo_module->config());
options.set_enable_conv_operand_swap(false);
options.set_enable_unconditional_reduce_of_concat_replacement(false);
pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(options);
pipeline.AddPass<HloPassFix<GpuAlgebraicSimplifier>>(options, gpu_version);

// tf2xla bridge, DepthwiseConvolutionConverter, GpuConvRewriter, and
// CudnnSimplifyPadding introduce reshapes and transposes. Run ReshapeMover
@@ -151,7 +152,7 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization(
     ReshapeMoverOptions reshape_mover_options;
     reshape_mover_options.reshape_of_1d_broadcast_is_cheap = true;
     pipeline.AddPass<ReshapeMover>(reshape_mover_options);
-    pipeline.AddPass<AlgebraicSimplifier>(options);
+    pipeline.AddPass<GpuAlgebraicSimplifier>(options, gpu_version);
   }();
 
   // The reshapes and transposes can possibly be eliminated using
@@ -162,7 +163,7 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization(
[&, &pipeline = pipeline.AddPass<HloPassFix<HloPassPipeline>>(
"simplify_after_conv_canonicalization")] {
pipeline.AddPass<ConvertMover>();
pipeline.AddPass<AlgebraicSimplifier>(options);
pipeline.AddPass<GpuAlgebraicSimplifier>(options, gpu_version);
}();

// GpuConvRewriter, GpuConvPaddingLegalization and
3 changes: 3 additions & 0 deletions xla/service/gpu/fusions/triton.cc
@@ -206,6 +206,9 @@ absl::StatusOr<FusionEmissionResult> TritonFusion::Emit(
       triton_config.set_block_k(64);
       triton_config.set_block_n(64);
       triton_config.set_split_k(1);
+      triton_config.set_num_stages(1);
+      triton_config.set_num_warps(2);
+      triton_config.set_num_ctas(1);
 
       block_level_parameters.num_ctas = 1;
       block_level_parameters.num_stages = 1;
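Filling in `num_stages`, `num_warps`, and `num_ctas` alongside the block sizes keeps the `triton_config` proto consistent with the `block_level_parameters` assignments just below — relevant on ROCm, where `num_stages` doubles as the double-buffering switch in `ir_emitter_triton_rocm.cc`.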
8 changes: 5 additions & 3 deletions xla/service/gpu/gemm_fusion.cc
@@ -801,10 +801,12 @@ absl::StatusOr<bool> GemmFusion::Run(
     const absl::flat_hash_set<absl::string_view>& execution_threads) {
   auto cuda_compute_capability =
       std::get_if<se::CudaComputeCapability>(&gpu_version_);
-  if (!cuda_compute_capability) {
+  auto rocm_compute_capability =
+      std::get_if<se::RocmComputeCapability>(&gpu_version_);
+  if (!cuda_compute_capability && !rocm_compute_capability) {
     return absl::FailedPreconditionError(
-        "Triton support is only enabled for CUDA GPUs.");
-  } else if (!cuda_compute_capability->IsAtLeastAmpere()) {
+        "Triton support is only enabled for CUDA and ROCM GPUs.");
+  } else if (cuda_compute_capability && !cuda_compute_capability->IsAtLeastAmpere()) {
     return absl::FailedPreconditionError(
         absl::StrCat("Triton support is only enabled for Ampere GPUs (compute ",
                      "capability 8.0) and up, but got compute capability ",
7 changes: 5 additions & 2 deletions xla/service/gpu/gpu_compiler.cc
@@ -1340,13 +1340,16 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
       gpu_target_config.device_description.gpu_compute_capability();
   pipeline.AddPass<AlgorithmChecker>(gpu_version);
   const auto* cuda_cc = std::get_if<se::CudaComputeCapability>(&gpu_version);
+  const auto* rocm_cc = std::get_if<se::RocmComputeCapability>(&gpu_version);
 
   // Rewrite FP8 GEMMs ahead of Triton which currently lacks support for FP8
   // and may rewrite quantized FP8 GEMMs as higher-precision GEMMs.
   pipeline.AddPass<GemmRewriter>(gpu_version, GetToolkitVersion(),
                                  /*f8_rewrite=*/true);
-  if (debug_options.xla_gpu_enable_triton_gemm() && cuda_cc != nullptr &&
-      cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) {
+  if (debug_options.xla_gpu_enable_triton_gemm() &&
+      ((cuda_cc != nullptr &&
+        cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) ||
+       rocm_cc != nullptr)) {
     pipeline.AddPass<GemvRewriter>();
     pipeline.AddPass<GemmFusion>(gpu_version);
   }
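Both this gate and the `GemmFusion::Run` check above dispatch over `se::GpuComputeCapability`, the `std::variant` of CUDA and ROCm capability types declared in `device_description.h`. A minimal sketch of the shared idiom (the helper is illustrative, not part of the PR):

```cpp
#include <variant>

#include "xla/stream_executor/device_description.h"

namespace se = ::stream_executor;

// Illustrative helper mirroring the new gate: Triton GEMM rewriting runs on
// Ampere-or-newer CUDA devices and, with this PR, on any ROCm device.
bool TritonGemmAllowed(const se::GpuComputeCapability& cc) {
  if (const auto* cuda = std::get_if<se::CudaComputeCapability>(&cc)) {
    return cuda->IsAtLeast(se::CudaComputeCapability::AMPERE);
  }
  return std::holds_alternative<se::RocmComputeCapability>(cc);
}
```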
51 changes: 31 additions & 20 deletions xla/service/gpu/ir_emitter_triton_rocm.cc
@@ -14,7 +14,8 @@ limitations under the License.
 ==============================================================================*/
 // TODO(ROCm): Enable and include ROCm Triton passes when ROCm Triton is
 // included in build.
-// #include "third_party/amd/include/TritonAMDGPUToLLVM/Passes.h"
+#include "third_party/amd/include/TritonAMDGPUToLLVM/Passes.h"
+#include "third_party/amd/include/TritonAMDGPUTransforms/Passes.h"
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"  // from @llvm-project
 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"  // from @llvm-project
 #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"  // from @llvm-project
@@ -35,6 +36,10 @@ limitations under the License.
 namespace xla {
 namespace gpu {
 
+// Value 0 for num_stages is used to represent AMD specific register
+// file double buffering.
+constexpr int kAmdDoubleBuffering = 0;
+
 namespace ma = ::mlir::arith;
 namespace mm = ::mlir::math;
 namespace ml = ::mlir::LLVM;
@@ -55,9 +60,10 @@ absl::Status CreateTritonPipeline(
   const int ccAsInt = 0;
   // TODO(ROCm): Check why some test fail when threadsPerWarp is set to 64.
   const int threadsPerWarp = 32;
+  auto ccRocm = std::get<se::RocmComputeCapability>(cc);
 
   // Based on make_ttir() in
-  // @triton//:third_party/nvidia/backend/compiler.py
+  // @triton//:third_party/amd/backend/compiler.py
   pm.addPass(mlir::createInlinerPass());
   pm.addPass(mt::createRewriteTensorPointerPass());
   pm.addPass(mt::createCombineOpsPass());
@@ -68,46 +74,51 @@
   pm.addPass(mlir::createSymbolDCEPass());
 
   // Based on make_ttgir() in
-  // @triton//:third_party/nvidia/backend/compiler.py
+  // @triton//:third_party/amd/backend/compiler.py
   pm.addPass(mt::createConvertTritonToTritonGPUPass(
-      absl::StrFormat("cuda:%u", ccAsInt), block_level_parameters.num_warps,
-      threadsPerWarp, block_level_parameters.num_ctas));
+      absl::StrCat("hip:", ccRocm.gfx_version()),
+      block_level_parameters.num_warps, threadsPerWarp,
+      block_level_parameters.num_ctas));
   pm.addPass(mt::gpu::createTritonGPUCoalesce());
   pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions());
   pm.addPass(mt::gpu::createTritonGPUOptimizeThreadLocality());
   pm.addPass(mt::gpu::createTritonGPUAccelerateMatmul());
   pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions());
-  // TODO ROCm Check if we want to compare MI100 and greater
+  pm.addPass(mlir::createTritonAMDGPUOptimizeEpiloguePass());
   pm.addPass(mt::gpu::createTritonGPUOptimizeDotOperands({true}));
   pm.addPass(mlir::createCSEPass());
-  pm.addPass(
-      mt::gpu::createTritonGPUPipeline({block_level_parameters.num_stages}));
-  pm.addPass(mt::gpu::createTritonGPUPrefetch());
+
+  // TODO ROCm Check if we want to compare MI100 and greater
+  if (block_level_parameters.num_stages == kAmdDoubleBuffering &&
+      ccRocm.has_amd_matrix_core()) {
+    pm.addPass(mlir::createTritonAMDGPUStreamPipelinePass());
+    pm.addPass(mlir::createCanonicalizerPass());
+  }
   pm.addPass(mt::gpu::createTritonGPUOptimizeDotOperands({true}));
   pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions());
   pm.addPass(mt::gpu::createTritonGPUReduceDataDuplication());
-  pm.addPass(mt::gpu::createTritonGPUReorderInstructions());
+  if (block_level_parameters.num_stages != kAmdDoubleBuffering) {
+    pm.addPass(mt::gpu::createTritonGPUReorderInstructions());
+  }
   pm.addPass(mlir::createCSEPass());
   pm.addPass(mlir::createSymbolDCEPass());
   pm.addPass(mlir::createCanonicalizerPass());
 
   // Based on make_llir() in
-  // @triton//:third_party/nvidia/backend/compiler.py
-  // pm.addPass(mt::gpu::createDecomposeUnsupportedConversionsPass());
+  // @triton//:third_party/amd/backend/compiler.py
+  pm.addPass(mlir::triton::AMD::createDecomposeUnsupportedConversionsPass(
+      ccRocm.gfx_version()));
   pm.addPass(mlir::createConvertSCFToCFPass());
   pm.addPass(mlir::createConvertIndexToLLVMPass());
   pm.addPass(mt::gpu::createAllocateSharedMemoryPass());
-  // pm.addPass(mt::createConvertTritonAMDGPUToLLVMPass());
-  pm.addPass(mlir::createArithToLLVMConversionPass());
+  pm.addPass(
+      mt::createConvertTritonAMDGPUToLLVMPass(ccRocm.gfx_version(), true));
   pm.addPass(mlir::createCanonicalizerPass());
   pm.addPass(mlir::createCSEPass());
   pm.addPass(mlir::createSymbolDCEPass());
-  // Note: translateTritonGPUToLLVMIR adds line info with LLVMDIScopePass.
   pm.addPass(mlir::createConvertSCFToCFPass());
   pm.addPass(mlir::createConvertControlFlowToLLVMPass());
 
   pm.addPass(mlir::createArithToLLVMConversionPass());
   pm.addPass(mlir::createCanonicalizerPass());
   pm.addPass(mlir::createCSEPass());
   pm.addPass(mlir::createSymbolDCEPass());
+  pm.addPass(mt::createConvertBuiltinFuncToLLVMPass());
   // There is no clusters in ROCm for now.
   out_cluster_info.clusterDimX = 1;
   out_cluster_info.clusterDimY = 1;
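Note how `num_stages` doubles as a mode switch: when it equals `kAmdDoubleBuffering` (0) and the device reports `has_amd_matrix_core()`, the AMD stream-pipelining pass takes over software pipelining, and the generic `TritonGPUReorderInstructions` pass is skipped, so the two scheduling strategies never run together.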
23 changes: 23 additions & 0 deletions xla/service/gpu/ir_emitter_triton_test.cc
@@ -82,6 +82,7 @@ class TritonGemmTest : public TritonTest {
     debug_options.set_xla_gpu_enable_split_k_autotuning(false);
     // Always rewrite Gemms with Triton regardless of size.
     debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0);
+    debug_options.set_xla_gpu_enable_triton_gemm(true);
     return debug_options;
   }
 
@@ -2414,6 +2415,9 @@

 TEST_F(TritonGemmTestAny,
        DoNotFuseConcatenationOfSplitNonContractingDimension) {
+  if (std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp())) {
+    GTEST_SKIP() << "Not using autotuner on ROCM yet.";
+  }
   if (SkipBF16Tests()) {
     GTEST_SKIP() << "BF16 not supported.";
   }
@@ -3235,6 +3239,10 @@ TEST_F(TritonGemmLevel2Test, SplitLHSInputOutputIsFused) {
   if (SkipBF16Tests()) {
     GTEST_SKIP() << "BF16 not supported.";
   }
+  if (std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp())) {
+    GTEST_SKIP() << "Skipped until corresponding issue on ROCm is fixed.";
+  }
+
   const std::string kHloText = R"(
 ENTRY e {
   p0t = (s8[5,18,20,150]) parameter(0)
@@ -3306,6 +3314,9 @@ ENTRY e {

 TEST_F(TritonGemmTestAny,
        LowerDotWithLhsWithoutNonContractingDimThroughTriton) {
+  if (std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp())) {
+    GTEST_SKIP() << "Not enough memory to allocate on ROCM.";
+  }
   const std::string hlo_text = R"(
 HloModule t
 
@@ -3328,6 +3339,9 @@ ENTRY e {

 TEST_F(TritonGemmTestAny,
        LowerDotWithRhsWithoutNonContractingDimThroughTriton) {
+  if (std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp())) {
+    GTEST_SKIP() << "Not enough memory to allocate on ROCM.";
+  }
   const std::string hlo_text = R"(
 HloModule t
 
@@ -3565,6 +3579,9 @@ ENTRY e {
 }
 
 TEST_F(CompareTest, UsingOptinSharedMemoryOnAmpereProducesSameResult) {
+  if (std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp())) {
+    GTEST_SKIP() << "No Optin Shared Memory on AMD.";
+  }
   const se::DeviceDescription dev_info =
       backend().default_stream_executor()->GetDeviceDescription();
   constexpr int kBytesOfSharedMemoryTested = 64 * 1024;
@@ -5011,6 +5028,9 @@ CHECK-COUNT-6: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16>
 }
 
 TEST_F(Triton6xBF16GemmTest, Emit6xBF16GemmEndToEnd) {
+  if (std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp())) {
+    GTEST_SKIP() << "ALG_DOT_BF16_BF16_F32_X6 not supported on ROCM.";
+  }
   const char* kHloText = R"(
 HloModule t
 
@@ -5347,6 +5367,9 @@ CHECK-COUNT-3: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16>
 }
 
 TEST_F(Triton3xBF16GemmTest, Emit3xBF16GemmEndToEnd) {
+  if (std::holds_alternative<se::RocmComputeCapability>(GpuComputeComp())) {
+    GTEST_SKIP() << "ALG_DOT_BF16_BF16_F32_X3 not supported on ROCM.";
+  }
   const char* kHloText = R"(
 HloModule t
 
6 changes: 3 additions & 3 deletions xla/service/gpu/ir_emitter_unnested.cc
@@ -1588,8 +1588,8 @@ absl::Status IrEmitterUnnested::EmitTopKCustomCall(

 absl::Status IrEmitterUnnested::EmitTritonCustomCall(
     const HloCustomCallInstruction* instr) {
-#if !GOOGLE_CUDA
-  return absl::UnimplementedError("Triton support requires CUDA");
+#if !GOOGLE_CUDA && !TENSORFLOW_USE_ROCM
+  return absl::UnimplementedError("Triton support requires CUDA or ROCm");
 #else
   auto generate = [this, &instr]() -> absl::StatusOr<KernelReuseCache::Entry> {
     mlir::MLIRContext& mlir_context = *ir_emitter_context_->mlir_context();
@@ -1617,7 +1617,7 @@ absl::Status IrEmitterUnnested::EmitTritonCustomCall(
   TF_ASSIGN_OR_RETURN(
       auto result,
       CompileTritonToLLVM(hlo_module->config(), hlo_module->name(),
-                          ir_emitter_context_->cuda_compute_capability(),
+                          ir_emitter_context_->gpu_compute_capability(),
                           ir_emitter_context_->gpu_device_info(),
                           block_level_parameters, triton_module.get(),
                           ir_emitter_context_->llvm_module(), mlir_context));
5 changes: 5 additions & 0 deletions xla/stream_executor/device_description.h
@@ -221,6 +221,11 @@ class RocmComputeCapability {

   bool has_mfma_instr_support() const { return gfx9_mi100_or_later(); }
 
+  bool has_amd_matrix_core() const {
+    return (gfx9_mi100_or_later() || gfx_version().find("gfx11") != std::string::npos ||
+            gfx_version().find("gfx12") != std::string::npos);
+  }
+
   bool has_fp16_atomics_support() const {
     // TODO(rocm): Check. This should be the same as has_fast_fp16_support().
     return gfx9_mi200_or_later();
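The explicit `!= std::string::npos` comparisons matter here: `std::string::find` returns an index, so using its result directly as a boolean is `false` precisely when the prefix matches at position 0 (e.g. `"gfx1100"`), inverting the intended test. A minimal demonstration of the pitfall:

```cpp
#include <string>

int main() {
  std::string v = "gfx1100";
  bool wrong = v.find("gfx11");                       // index 0 -> false (!)
  bool right = v.find("gfx11") != std::string::npos;  // found -> true
  return wrong == right ? 1 : 0;  // returns 0: the two tests disagree
}
```

`gfx9_mi100_or_later()` covers the CDNA parts with MFMA units, while the substring checks pick up RDNA3/RDNA4-style `gfx11*`/`gfx12*` names, whose matrix cores are exposed as WMMA instead.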