PR #18838: [NVIDIA GPU] Support multi-operand collective-permute
Imported from GitHub PR #18838

For collective-permutes with small message sizes, it is beneficial to combine them into a single collective because
1. it eliminates some kernel launch overhead and lets NCCL fuse messages;
2. fewer collectives make it easier for the latency-hiding scheduler (LHS) to make better decisions.

To support combining collective-permutes, we first need to support multi-operand collective-permute, a.k.a. the combined collective-permute. This PR extends the existing CP interface by overloading it, so that a CP can have multiple operands.
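
As an illustration, the combined form looks like the following HLO. This is a minimal sketch mirroring the CombinedCollectivePermuteStartAndDone test added below; the shapes and source_target_pairs are illustrative:

  p0 = f32[2] parameter(0)
  p1 = f32[2] parameter(1)
  start = ((f32[2], f32[2]), (f32[2], f32[2])) collective-permute-start(p0, p1), source_target_pairs={{0,1},{1,0}}
  ROOT done = (f32[2], f32[2]) collective-permute-done(start)

The start op's result groups the aliased operand values at tuple index {0} and the destination buffers at index {1}, one element per operand.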
Copybara import of the project:

--
5e10aba by Terry Sun <[email protected]>:

support multi-operand cp

--
170fead by Terry Sun <[email protected]>:

minor refactoring

--
0d85070 by Terry Sun <[email protected]>:

update python interface

--
9812a10 by Terry Sun <[email protected]>:

polish python interface

--
3a1552c by Terry Sun <[email protected]>:

formatting

--
d3657f8 by Terry Sun <[email protected]>:

formatting

--
c9202fa by Terry Sun <[email protected]>:

fix minor issues

Merging this change closes #18838

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18838 from terryysun:terryysun/grouped_cp c9202fa
PiperOrigin-RevId: 693728463
terryysun authored and Google-ML-Automation committed Nov 19, 2024
1 parent e0300b0 commit 599290f
Showing 26 changed files with 653 additions and 161 deletions.
46 changes: 28 additions & 18 deletions xla/hlo/analysis/hlo_dataflow_analysis.cc
@@ -1111,26 +1111,32 @@ bool HloDataflowAnalysis::UpdateCollectivePermuteStartValueSet(
   bool changed = false;
   // CollectivePermuteStart forwards the operand value to element {0} of its
   // output.
-  if (collective_permute_start->operand(0)->shape().IsTuple()) {
-    for (int i = 0; i < ShapeUtil::TupleElementCount(
-                        collective_permute_start->operand(0)->shape());
-         ++i) {
-      const HloValueSet& operand_value_set =
-          GetValueSet(collective_permute_start->operand(0), {i});
-      HloValueSet& value_set = GetValueSet(collective_permute_start, {0, i});
+  for (int oprd_idx = 0; oprd_idx < collective_permute_start->operands().size();
+       ++oprd_idx) {
+    if (collective_permute_start->operand(oprd_idx)->shape().IsTuple()) {
+      for (int i = 0;
+           i < ShapeUtil::TupleElementCount(
+                   collective_permute_start->operand(oprd_idx)->shape());
+           ++i) {
+        const HloValueSet& operand_value_set =
+            GetValueSet(collective_permute_start->operand(oprd_idx), {i});
+        HloValueSet& value_set =
+            GetValueSet(collective_permute_start, {0, oprd_idx, i});
+        if (value_set != operand_value_set) {
+          value_set = operand_value_set;
+          changed = true;
+        }
+      }
+    } else {
+      const HloValueSet& operand_value_set =
+          GetValueSet(collective_permute_start->operand(oprd_idx));
+      HloValueSet& value_set =
+          GetValueSet(collective_permute_start, {0, oprd_idx});
       if (value_set != operand_value_set) {
         value_set = operand_value_set;
         changed = true;
       }
     }
-  } else {
-    const HloValueSet& operand_value_set =
-        GetValueSet(collective_permute_start->operand(0));
-    HloValueSet& value_set = GetValueSet(collective_permute_start, {0});
-    if (value_set != operand_value_set) {
-      value_set = operand_value_set;
-      changed = true;
-    }
   }
   return changed;
 }
@@ -1579,16 +1585,16 @@ absl::Status HloDataflowAnalysis::InitializeInstructionValueSets() {
         // AllReduceDone's output aliases its input.
         break;
       case HloOpcode::kCollectivePermuteStart:
-        // CollectivePermuteStart produces a tuple of
-        // {aliased operand, destination buffer, contexts}, where the context
-        // data are optional.
+        // CollectivePermuteStart produces a tuple of {{aliased operand(s)},
+        // {destination buffer(s)}, contexts}, where the context data are
+        // optional.
         define_value_at(/*index=*/{});
         define_value_at(/*index=*/{1});
         for (int i = 2; i < instruction->shape().tuple_shapes_size(); ++i) {
           define_value_at(/*index=*/{i});
         }
 
-        if (instruction->operand_count() > 1) {
+        if (Cast<HloCollectivePermuteInstruction>(instruction)->inplace()) {
           CHECK_EQ(instruction->operand_count(), 4);
           if (instruction->operand(1)->shape().IsTuple()) {
             for (int i = 0; i < ShapeUtil::TupleElementCount(
@@ -1597,6 +1603,10 @@ absl::Status HloDataflowAnalysis::InitializeInstructionValueSets() {
               define_value_at(/*index=*/{1, i});
             }
           }
+        } else if (instruction->operand_count() > 1) {
+          for (int i = 0; i < instruction->operand_count(); ++i) {
+            define_value_at(/*index=*/{1, i});
+          }
         }
         break;
       case HloOpcode::kCollectivePermuteDone:
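
To make the index bookkeeping above concrete, this is the value-index layout for a two-operand combined start, as exercised by the new dataflow test in the next file (a sketch in comment form, showing only the two-operand f32[2] case from that test):

  // start = ((f32[2], f32[2]), (f32[2], f32[2])) collective-permute-start(p0, p1)
  //          \_____ {0} _____/  \_____ {1} _____/
  // {0, i} holds the forwarded value of operand i, so no new value is defined
  // there; {1, i} is the destination buffer for operand i and is defined at
  // start. collective-permute-done(start) then yields (f32[2], f32[2]), whose
  // element {i} comes from start's {1, i}.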
46 changes: 46 additions & 0 deletions xla/hlo/analysis/hlo_dataflow_analysis_test.cc
@@ -2155,6 +2155,52 @@ TEST_F(HloDataflowAnalysisTest, AllReduceStartAndDoneTwoOperands) {
                   UnorderedElementsAre(HloUse{done, 0, {}}));
 }
 
+TEST_F(HloDataflowAnalysisTest, CombinedCollectivePermuteStartAndDone) {
+  const char* hlo_text = R"(
+HloModule test
+ENTRY entry {
+  p0 = f32[2] parameter(0)
+  p1 = f32[2] parameter(1)
+  start = ((f32[2], f32[2]), (f32[2], f32[2])) collective-permute-start(p0, p1), source_target_pairs={{0,1},{1,0}}
+  ROOT done = (f32[2], f32[2]) collective-permute-done(start)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_text));
+  const HloDataflowAnalysis& analysis = RunAnalysis(/*ssa_form=*/false);
+  absl::Status status = analysis.Verify();
+  EXPECT_TRUE(status.ok()) << status;
+
+  HloInstruction* done = module_->entry_computation()->root_instruction();
+  HloInstruction* start = done->mutable_operand(0);
+  HloInstruction* param0 = start->mutable_operand(0);
+  HloInstruction* param1 = start->mutable_operand(1);
+
+  EXPECT_TRUE(analysis.ValueIsDefinedAt(start, /*index=*/{}));
+  EXPECT_TRUE(analysis.ValueIsDefinedAt(start, /*index=*/{1}));
+  EXPECT_TRUE(analysis.ValueIsDefinedAt(start, /*index=*/{1, 0}));
+  EXPECT_TRUE(analysis.ValueIsDefinedAt(start, /*index=*/{1, 1}));
+
+  EXPECT_TRUE(analysis.ValueIsDefinedAt(done, /*index=*/{}));
+  EXPECT_FALSE(analysis.ValueIsDefinedAt(done, /*index=*/{0}));
+  EXPECT_FALSE(analysis.ValueIsDefinedAt(done, /*index=*/{1}));
+
+  EXPECT_THAT(
+      analysis.GetValueDefinedAt(param0).GetUses(),
+      UnorderedElementsAre(HloUse{start, 0, {}}, HloUse{done, 0, {0, 0}}));
+  EXPECT_THAT(
+      analysis.GetValueDefinedAt(param1).GetUses(),
+      UnorderedElementsAre(HloUse{start, 1, {}}, HloUse{done, 0, {0, 1}}));
+
+  EXPECT_THAT(HloValuesAt(start, /*index=*/{0, 0}),
+              UnorderedElementsAre(&analysis.GetValueDefinedAt(param0, {})));
+  EXPECT_THAT(HloValuesAt(start, /*index=*/{0, 1}),
+              UnorderedElementsAre(&analysis.GetValueDefinedAt(param1, {})));
+  EXPECT_THAT(HloValuesAt(done, /*index=*/{0}),
+              UnorderedElementsAre(&analysis.GetValueDefinedAt(start, {1, 0})));
+  EXPECT_THAT(HloValuesAt(done, /*index=*/{1}),
+              UnorderedElementsAre(&analysis.GetValueDefinedAt(start, {1, 1})));
+}
+
 TEST_F(HloDataflowAnalysisTest, AllGatherStartAndDoneWithTuple) {
   const char* hlo_text = R"(
 HloModule test
86 changes: 78 additions & 8 deletions xla/hlo/builder/xla_builder.cc
@@ -284,9 +284,19 @@ XlaOp XlaBuilderFriend::BuildCopyDone(XlaBuilder* builder, const XlaOp operand,
 XlaOp XlaBuilderFriend::BuildCollectivePermuteStart(
     XlaBuilder* builder, XlaOp operand,
     const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
-    const std::optional<ChannelHandle>& channel_id) {
+    const std::optional<ChannelHandle>& channel_id, const bool inplace) {
   return builder->CollectivePermuteImpl(operand, source_target_pairs,
-                                        channel_id, /*async=*/true);
+                                        channel_id, /*async=*/true, inplace);
+}
+
+XlaOp XlaBuilderFriend::BuildCollectivePermuteStart(
+    XlaBuilder* builder, absl::Span<const XlaOp> operands,
+    const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
+    const std::optional<ChannelHandle>& channel_id, const bool inplace) {
+  // TODO support multi-operand in-place collective permute
+  CHECK(!inplace);
+  return builder->CollectivePermuteImpl(operands, source_target_pairs,
+                                        channel_id, /*async=*/true, inplace);
 }

XlaOp XlaBuilderFriend::BuildCollectivePermuteDone(XlaBuilder* builder,
@@ -4083,21 +4093,32 @@ XlaOp XlaBuilder::CollectiveBroadcastImpl(
 XlaOp XlaBuilder::CollectivePermute(
     XlaOp operand,
     const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
-    const std::optional<ChannelHandle>& channel_id) {
+    const std::optional<ChannelHandle>& channel_id, const bool inplace) {
   return CollectivePermuteImpl(operand, source_target_pairs, channel_id,
-                               /*async=*/false);
+                               /*async=*/false, inplace);
+}
+
+XlaOp XlaBuilder::CollectivePermute(
+    absl::Span<const XlaOp> operands,
+    const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
+    const std::optional<ChannelHandle>& channel_id, const bool inplace) {
+  // TODO support multi-operand in-place collective permute
+  CHECK(!inplace);
+  return CollectivePermuteImpl(operands, source_target_pairs, channel_id,
+                               /*async=*/false, inplace);
 }
 
 XlaOp XlaBuilder::CollectivePermuteImpl(
     XlaOp operand,
     const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
-    const std::optional<ChannelHandle>& channel_id, bool async) {
+    const std::optional<ChannelHandle>& channel_id, bool async,
+    const bool inplace) {
   return ReportErrorOrReturn([&]() -> absl::StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
     HloInstructionProto instr;
     TF_ASSIGN_OR_RETURN(
         Shape shape,
-        ShapeInference::InferCollectivePermuteShape({operand_shape}));
+        ShapeInference::InferCollectivePermuteShape(operand_shape, inplace));
     *instr.mutable_shape() = shape.ToProto();
 
     for (const auto& pair : source_target_pairs) {
@@ -4116,6 +4137,45 @@ XlaOp XlaBuilder::CollectivePermuteImpl(
   });
 }
 
+XlaOp XlaBuilder::CollectivePermuteImpl(
+    absl::Span<const XlaOp> operands,
+    const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
+    const std::optional<ChannelHandle>& channel_id, bool async,
+    const bool inplace) {
+  // TODO support multi-operand in-place collective permute
+  CHECK(!inplace);
+  return ReportErrorOrReturn([&]() -> absl::StatusOr<XlaOp> {
+    std::vector<const Shape*> operand_shapes;
+    for (const auto& operand : operands) {
+      TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
+      operand_shapes.push_back(operand_shape);
+    }
+    CHECK_GT(operand_shapes.size(), 1);
+    HloInstructionProto instr;
+    auto tuple_operand_shapes =
+        ShapeUtil::MakeTupleShapeWithPtrs(operand_shapes);
+    TF_ASSIGN_OR_RETURN(Shape shape,
+                        ShapeInference::InferCollectivePermuteShape(
+                            &tuple_operand_shapes, inplace));
+    *instr.mutable_shape() =
+        ShapeUtil::MakeTupleShapeWithPtrs(operand_shapes).ToProto();
+
+    for (const auto& pair : source_target_pairs) {
+      auto* proto_pair = instr.add_source_target_pairs();
+      proto_pair->set_source(pair.first);
+      proto_pair->set_target(pair.second);
+    }
+    if (channel_id.has_value()) {
+      instr.set_channel_id(channel_id->handle());
+    }
+
+    return AddInstruction(std::move(instr),
+                          async ? HloOpcode::kCollectivePermuteStart
+                                : HloOpcode::kCollectivePermute,
+                          operands);
+  });
+}
+
 XlaOp XlaBuilder::ReplicaId() {
   return ReportErrorOrReturn([&]() -> absl::StatusOr<XlaOp> {
     HloInstructionProto instr;
@@ -5630,9 +5690,19 @@ XlaOp CollectiveBroadcast(const XlaOp operand,
 XlaOp CollectivePermute(
     const XlaOp operand,
     const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
-    const std::optional<ChannelHandle>& channel_id) {
+    const std::optional<ChannelHandle>& channel_id, const bool inplace) {
   return operand.builder()->CollectivePermute(operand, source_target_pairs,
-                                              channel_id);
+                                              channel_id, inplace);
+}
+
+XlaOp MultiCollectivePermute(
+    absl::Span<const XlaOp> operands,
+    const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
+    const std::optional<ChannelHandle>& channel_id, const bool inplace) {
+  // TODO support multi-operand in-place collective permute
+  CHECK(!inplace);
+  return operands.at(0).builder()->CollectivePermute(
+      operands, source_target_pairs, channel_id, inplace);
 }
 
 XlaOp ReplicaId(XlaBuilder* builder) { return builder->ReplicaId(); }
40 changes: 35 additions & 5 deletions xla/hlo/builder/xla_builder.h
@@ -100,7 +100,13 @@ struct XlaBuilderFriend {
   static XlaOp BuildCollectivePermuteStart(
       XlaBuilder* builder, XlaOp operand,
       const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
-      const std::optional<ChannelHandle>& channel_id = std::nullopt);
+      const std::optional<ChannelHandle>& channel_id = std::nullopt,
+      const bool inplace = false);
+  static XlaOp BuildCollectivePermuteStart(
+      XlaBuilder* builder, absl::Span<const XlaOp> operands,
+      const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
+      const std::optional<ChannelHandle>& channel_id = std::nullopt,
+      const bool inplace = false);
   static XlaOp BuildCollectivePermuteDone(XlaBuilder* builder, XlaOp operands,
                                           const Shape& shape);
 
@@ -880,7 +886,14 @@ class XlaBuilder {
   XlaOp CollectivePermute(
       XlaOp operand,
       const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
-      const std::optional<ChannelHandle>& channel_id = std::nullopt);
+      const std::optional<ChannelHandle>& channel_id = std::nullopt,
+      const bool inplace = false);
+
+  XlaOp CollectivePermute(
+      absl::Span<const XlaOp> operands,
+      const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
+      const std::optional<ChannelHandle>& channel_id = std::nullopt,
+      const bool inplace = false);
 
   XlaOp ReplicaId();
 
@@ -1562,7 +1575,11 @@ class XlaBuilder {
   friend XlaOp CollectivePermute(
       XlaOp operand,
       const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
-      const std::optional<ChannelHandle>& channel_id);
+      const std::optional<ChannelHandle>& channel_id, const bool inplace);
+  friend XlaOp MultiCollectivePermute(
+      absl::Span<const XlaOp> operands,
+      const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
+      const std::optional<ChannelHandle>& channel_id, const bool inplace);
   friend XlaOp ReplicaId(XlaBuilder* builder);
   friend XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select,
                                 absl::Span<const int64_t> window_dimensions,
@@ -1714,7 +1731,14 @@ class XlaBuilder {
   XlaOp CollectivePermuteImpl(
       XlaOp operand,
       const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
-      const std::optional<ChannelHandle>& channel_id, bool async);
+      const std::optional<ChannelHandle>& channel_id, bool async,
+      const bool inplace);
+
+  XlaOp CollectivePermuteImpl(
+      absl::Span<const XlaOp> operands,
+      const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
+      const std::optional<ChannelHandle>& channel_id, bool async,
+      const bool inplace);
 
   XlaOp ConditionalImpl(
       XlaOp branch_index,
@@ -2647,7 +2671,13 @@ XlaOp CollectiveBroadcast(
 XlaOp CollectivePermute(
     XlaOp operand,
     const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
-    const std::optional<ChannelHandle>& channel_id = std::nullopt);
+    const std::optional<ChannelHandle>& channel_id = std::nullopt,
+    const bool inplace = false);
+XlaOp MultiCollectivePermute(
+    absl::Span<const XlaOp> operands,
+    const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
+    const std::optional<ChannelHandle>& channel_id = std::nullopt,
+    const bool inplace = false);
 
 // Enqueues an operation that returns the replica ID.
 XlaOp ReplicaId(XlaBuilder* builder);
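
For reference, a sketch of how the new multi-operand entry points declared above might be driven end to end. This is hypothetical usage assembled from these declarations and the CombinedCollectivePermute test below, not code shipped in this PR; the shapes, source_target_pairs, and done_shape are illustrative:

  XlaBuilder b("combined_cp");
  XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
  XlaOp y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {5, 7}), "y");
  // Synchronous combined form; inplace defaults to false, and the
  // multi-operand overloads CHECK that it stays false for now.
  MultiCollectivePermute({x, y}, {{0, 1}, {1, 0}});
  // Async form via the friend API: start the combined transfer, then mark it
  // done with the expected result shape (one element per operand).
  XlaOp start = XlaBuilderFriend::BuildCollectivePermuteStart(
      &b, {x, y}, {{0, 1}, {1, 0}});
  Shape done_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {5, 7}), ShapeUtil::MakeShape(F32, {5, 7})});
  XlaOp done =
      XlaBuilderFriend::BuildCollectivePermuteDone(&b, start, done_shape);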
9 changes: 9 additions & 0 deletions xla/hlo/builder/xla_builder_test.cc
@@ -831,6 +831,15 @@ TEST(XlaBuilderTest, CollectivePermute) {
   EXPECT_EQ(GetRoot(*module)->opcode(), HloOpcode::kCollectivePermute);
 }
 
+TEST(XlaBuilderTest, CombinedCollectivePermute) {
+  XlaBuilder b(TestName());
+  auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
+  auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {5, 7}), "y");
+  MultiCollectivePermute({x, y}, {{0, 1}, {1, 2}, {2, 3}});
+  TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b));
+  EXPECT_EQ(GetRoot(*module)->opcode(), HloOpcode::kCollectivePermute);
+}
+
 TEST(XlaBuilderTest, GetDimensionSize) {
   XlaBuilder b(TestName());
   auto x =