Skip to content

Commit

Permalink
[XLA:GPU] Add intra-warp reduce of reduce test.
Browse files Browse the repository at this point in the history
Add a reproducer from b/380277401 as a test to make sure it doesn't get broken again later.

Reduce op lowering needs special handling if the input parameter has slice layout. The issue [0] was fixed in upstream Triton in June 2024 [1], but later lost and re-fixed in [2].

[0] triton-lang/triton#4116
[1] triton-lang/triton#4139
[2] triton-lang/triton#5080

PiperOrigin-RevId: 700308452
  • Loading branch information
olegshyshkov authored and tensorflower-gardener committed Nov 26, 2024
1 parent 7611055 commit 691aa73
Showing 1 changed file with 41 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1403,6 +1403,47 @@ CHECK: tt.store {{.*}} !tt.ptr<f32>
EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, ErrorSpec{0, 0}));
}

// Reproducer from b/380277401.
TEST_F(TritonEmitterTest, IntraWarpReduceOfReduceIsCorrect) {
const std::string kHloText = R"(
add {
x = s32[] parameter(0)
y = s32[] parameter(1)
ROOT add = s32[] add(x, y)
}
triton_computation {
p = s32[4,8] parameter(0)
bitcast = s32[4,2,4] bitcast(p)
zero = s32[] constant(0)
reduce_1 = s32[4,2] reduce(bitcast, zero), dimensions={2}, to_apply=add
ROOT reduce_2 = s32[2] reduce(reduce_1, zero), dimensions={0}, to_apply=add
}
ENTRY entry_computation {
i = s32[32] iota(), iota_dimension=0
p = s32[4,8] bitcast(i)
ROOT r = s32[2] fusion(p),
kind=kCustom, calls=triton_computation,
backend_config={
"fusion_backend_config":{"kind":"__triton","block_level_fusion_config":
{"output_tile_sizes":["2"],"num_warps":"1"}}}
})";
TF_EXPECT_OK(
CreateTritonIrAndFileCheck(this, kHloText, "triton_computation", R"(
CHECK: tt.load
CHECK: tt.reshape
CHECK: tt.reduce
CHECK: tt.reduce
CHECK: tt.store
)"));

EXPECT_TRUE(
RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/0, /*arel=*/0}));
}

} // namespace
} // namespace gpu
} // namespace xla

0 comments on commit 691aa73

Please sign in to comment.