Skip to content

Commit

Permalink
PR #19372: [GPU] Consider small kInput fusions with concatenations in…
Browse files Browse the repository at this point in the history
… the horizontal loop fusion pass.

Imported from GitHub PR #19372

Copybara import of the project:

--
9990047 by Ilia Sergachev <[email protected]>:

[GPU] Consider small kInput fusions with concatenations in the horizontal loop fusion pass.

Merging this change closes #19372

COPYBARA_INTEGRATE_REVIEW=#19372 from openxla:horizontal_fusion_concat 9990047
PiperOrigin-RevId: 697905107
  • Loading branch information
sergachev authored and Google-ML-Automation committed Nov 19, 2024
1 parent e4eeddb commit a9d7685
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 5 deletions.
13 changes: 8 additions & 5 deletions xla/service/gpu/transforms/horizontal_loop_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,11 @@ class HorizontalLoopFusionImpl {
std::string prefix_;
}; // HorizontalLoopFusionImpl

bool IsConcatenationInputFusion(const HloInstruction& instr) {
return instr.IsInputFusion() &&
instr.fused_expression_root()->opcode() == HloOpcode::kConcatenate;
}

bool IsFusibleCandidate(const HloInstruction& instr) {
// For now, we do not support fusing instruction with control flow.
if (!instr.control_successors().empty() ||
Expand All @@ -158,8 +163,7 @@ bool IsFusibleCandidate(const HloInstruction& instr) {
return true;
}

// Exclude fusions other than kLoop.
if (!instr.IsLoopFusion()) {
if (!(instr.IsLoopFusion() || IsConcatenationInputFusion(instr))) {
return false;
}

Expand Down Expand Up @@ -196,7 +200,8 @@ bool IsProfitableFusionCandidate(const HloInstruction& instr,
// GPU thread can only process 1 element. From experience, we enable larger
// tensor size threshold for kLoop fusion.
const int64_t kShapeThreshold =
sliced_input_fusion ? 128 * 2048 : 8192 * 8192;
(sliced_input_fusion || IsConcatenationInputFusion(instr)) ? 128 * 2048
: 8192 * 8192;
const int64_t kInstrCountThreshold = sliced_input_fusion ? 30 : 128;
const HloInstruction* root = (instr.opcode() == HloOpcode::kFusion)
? instr.fused_expression_root()
Expand Down Expand Up @@ -253,8 +258,6 @@ void HorizontalLoopFusionImpl::FusionCandidates::Initialize(
std::vector<HloInstruction*> ordered_fusible_candidates;
for (HloInstruction* opnd : consumer->operands()) {
HloInstruction* predecessor = opnd->LatestNonGteAncestor();
// We support kLoop fusion and element-wise HLOs now. We may extend the
// support list if needs arise.
if (IsFusibleCandidate(*predecessor)) {
if (fusible_candidates.insert(predecessor).second) {
// Add unseen fusion to ordered list.
Expand Down
54 changes: 54 additions & 0 deletions xla/service/gpu/transforms/horizontal_loop_fusion_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,60 @@ e {
EXPECT_FALSE(HorizontalLoopFusion().Run(module.get()).value());
}

TEST_F(HorizontalLoopFusionTest, FuseSmallConcatenationInputFusions) {
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
a {
p = s4[1] parameter(0)
q = s4[2] parameter(1)
c = s4[3] concatenate(p, q), dimensions={0}
}
b {
p = s4[4] parameter(0)
q = s4[5] parameter(1)
c = s4[9] concatenate(p, q), dimensions={0}
}
e {
p = s4[1] constant({...})
q = s4[2] constant({...})
x = s4[3] fusion(p, q), kind=kInput, calls=a
r = s4[4] constant({...})
s = s4[5] constant({...})
y = s4[9] fusion(r, s), kind=kInput, calls=b
t = tuple(x, y)
})"));

EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value());
}

TEST_F(HorizontalLoopFusionTest, DoNotFuseLargerConcatenationInputFusions) {
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
a {
p = s4[100000] parameter(0)
q = s4[200000] parameter(1)
c = s4[300000] concatenate(p, q), dimensions={0}
}
b {
p = s4[200000] parameter(0)
q = s4[100000] parameter(1)
c = s4[300000] concatenate(p, q), dimensions={0}
}
e {
p = s4[100000] constant({...})
q = s4[200000] constant({...})
x = s4[300000] fusion(p, q), kind=kInput, calls=a
r = s4[200000] constant({...})
s = s4[100000] constant({...})
y = s4[300000] fusion(r, s), kind=kInput, calls=b
t = tuple(x, y)
})"));

EXPECT_FALSE(HorizontalLoopFusion().Run(module.get()).value());
}

TEST_F(HorizontalLoopFusionTest, IterativeHorizontalFusion) {
auto module = ParseAndReturnVerifiedModule(R"(
HloModule NonfusionInstrs
Expand Down

0 comments on commit a9d7685

Please sign in to comment.