Skip to content

Commit

Permalink
[Triton] Restricting failing configurations on certain microbenchmark…
Browse files Browse the repository at this point in the history
…s. The failures are all CUDA_ERROR_ILLEGAL_ADDRESS which seem to occur with block_m=16 / block_n=16 with num_warps=16.

PiperOrigin-RevId: 700307430
  • Loading branch information
Moerafaat authored and tensorflower-gardener committed Nov 26, 2024
1 parent 757ce6f commit 7611055
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -910,16 +910,15 @@ GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) {
// on small block_k values depending on the bit-width of the inputs to the
// dot. The logic below accounts for this limitation.
constexpr int kLdmatrixGranularity = 256;
if (config.block_k < kLdmatrixGranularity / minBitWidth) {
config.block_k = kLdmatrixGranularity / minBitWidth;

// Additionally, there are further issues happening on FP8 types that
// require additional restriction on block_m to avoid failures similar to
// b/378660935.
if (isF8Dot) {
config.block_m =
std::max(config.block_m, kLdmatrixGranularity / minBitWidth);
}
config.block_k =
std::max(config.block_k, kLdmatrixGranularity / minBitWidth);

// Additionally, there are further issues happening on FP8 types and
// predicates that require an additional restriction on block_m when
// num_warps > 8 (see b/378660935). It's unclear whether the issue extends
// beyond these cases, so the restriction is conservatively limited to them.
if ((isF8Dot || minBitWidth == 1) && config.num_warps > 8) {
config.block_m = std::max(config.block_m, 32);
}

// Sparse meta should have at least one element per thread.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1020,6 +1020,36 @@ ENTRY e {
[](const TritonGemmConfig& config) { return config.block_k > 16; }));
}

// TODO(b/337839570): In addition to Triton's existing limitations on small
// block_k values, there are further issues happening on FP8 types and
// predicates that require an additional restriction on block_m when
// num_warps > 8 (see b/378660935). It's unclear whether the issue extends
// beyond these cases, so the restrictions are conservatively limited to them.
// Checks the exhaustive Triton config list produced for an FP8 x FP8 dot on
// Ampere: no config may have block_k <= 16, and no config with num_warps > 8
// may have block_m <= 16.
TEST_F(GemmFusionAutotunerExhaustiveTest, SkipsCrashingConfigsFP8Dot) {
  constexpr char kHloText[] = R"(
HloModule module
ENTRY e {
x = f8e4m3fn[33,33]{1,0} parameter(0)
y = f8e4m3fn[33,33]{1,0} parameter(1)
ROOT out = bf16[33,33]{1,0} dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
  std::unique_ptr<VerifiedHloModule> module =
      ParseAndReturnVerifiedModule(kHloText).value();

  // The dot under test is the root of the entry computation.
  const auto* fp8_dot = Cast<HloDotInstruction>(
      module->entry_computation()->root_instruction());
  const se::CudaComputeCapability ampere{se::CudaComputeCapability::AMPERE,
                                         /*minor=*/0};

  TF_ASSERT_OK_AND_ASSIGN(
      const std::vector<TritonGemmConfig> configs,
      GetPossibleMatmulAutotuneTritonConfigs(
          *fp8_dot, ampere, GetToolkitVersion(), GetDebugOptionsForTest()));

  // A config is rejected if its block_k is too small, or if it combines more
  // than 8 warps with a block_m that is too small.
  const auto is_disallowed = [](const TritonGemmConfig& candidate) {
    const bool small_block_k = candidate.block_k <= 16;
    const bool small_block_m_many_warps =
        candidate.num_warps > 8 && candidate.block_m <= 16;
    return small_block_k || small_block_m_many_warps;
  };
  EXPECT_TRUE(std::none_of(configs.begin(), configs.end(), is_disallowed));
}

class GemmFusionAutotunerDisableSplitK : public GemmFusionAutotunerTest {
public:
DebugOptions GetDebugOptionsForTest() const override {
Expand Down

0 comments on commit 7611055

Please sign in to comment.