Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PR #18838: [NVIDIA GPU] Support multi-operand collective-permute #19424

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 28 additions & 18 deletions xla/hlo/analysis/hlo_dataflow_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1111,26 +1111,32 @@ bool HloDataflowAnalysis::UpdateCollectivePermuteStartValueSet(
bool changed = false;
// CollectivePermuteStart forwards the operand value to element {0} of its
// output.
if (collective_permute_start->operand(0)->shape().IsTuple()) {
for (int i = 0; i < ShapeUtil::TupleElementCount(
collective_permute_start->operand(0)->shape());
++i) {
for (int oprd_idx = 0; oprd_idx < collective_permute_start->operands().size();
++oprd_idx) {
if (collective_permute_start->operand(oprd_idx)->shape().IsTuple()) {
for (int i = 0;
i < ShapeUtil::TupleElementCount(
collective_permute_start->operand(oprd_idx)->shape());
++i) {
const HloValueSet& operand_value_set =
GetValueSet(collective_permute_start->operand(oprd_idx), {i});
HloValueSet& value_set =
GetValueSet(collective_permute_start, {0, oprd_idx, i});
if (value_set != operand_value_set) {
value_set = operand_value_set;
changed = true;
}
}
} else {
const HloValueSet& operand_value_set =
GetValueSet(collective_permute_start->operand(0), {i});
HloValueSet& value_set = GetValueSet(collective_permute_start, {0, i});
GetValueSet(collective_permute_start->operand(oprd_idx));
HloValueSet& value_set =
GetValueSet(collective_permute_start, {0, oprd_idx});
if (value_set != operand_value_set) {
value_set = operand_value_set;
changed = true;
}
}
} else {
const HloValueSet& operand_value_set =
GetValueSet(collective_permute_start->operand(0));
HloValueSet& value_set = GetValueSet(collective_permute_start, {0});
if (value_set != operand_value_set) {
value_set = operand_value_set;
changed = true;
}
}
return changed;
}
Expand Down Expand Up @@ -1579,16 +1585,16 @@ absl::Status HloDataflowAnalysis::InitializeInstructionValueSets() {
// AllReduceDone's output aliases its input.
break;
case HloOpcode::kCollectivePermuteStart:
// CollectivePermuteStart produces a tuple of
// {aliased operand, destination buffer, contexts}, where the context
// data are optional.
// CollectivePermuteStart produces a tuple of {{aliased operand(s)},
// {destination buffer(s)}, contexts}, where the context data are
// optional.
define_value_at(/*index=*/{});
define_value_at(/*index=*/{1});
for (int i = 2; i < instruction->shape().tuple_shapes_size(); ++i) {
define_value_at(/*index=*/{i});
}

if (instruction->operand_count() > 1) {
if (Cast<HloCollectivePermuteInstruction>(instruction)->inplace()) {
CHECK_EQ(instruction->operand_count(), 4);
if (instruction->operand(1)->shape().IsTuple()) {
for (int i = 0; i < ShapeUtil::TupleElementCount(
Expand All @@ -1597,6 +1603,10 @@ absl::Status HloDataflowAnalysis::InitializeInstructionValueSets() {
define_value_at(/*index=*/{1, i});
}
}
} else if (instruction->operand_count() > 1) {
for (int i = 0; i < instruction->operand_count(); ++i) {
define_value_at(/*index=*/{1, i});
}
}
break;
case HloOpcode::kCollectivePermuteDone:
Expand Down
46 changes: 46 additions & 0 deletions xla/hlo/analysis/hlo_dataflow_analysis_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2155,6 +2155,52 @@ TEST_F(HloDataflowAnalysisTest, AllReduceStartAndDoneTwoOperands) {
UnorderedElementsAre(HloUse{done, 0, {}}));
}

TEST_F(HloDataflowAnalysisTest, CombinedCollectivePermuteStartAndDone) {
const char* hlo_text = R"(
HloModule test
ENTRY entry {
p0 = f32[2] parameter(0)
p1 = f32[2] parameter(1)
start = ((f32[2], f32[2]), (f32[2], f32[2])) collective-permute-start(p0, p1), source_target_pairs={{0,1},{1,0}}
ROOT done = (f32[2], f32[2]) collective-permute-done(start)
}
)";
TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_text));
const HloDataflowAnalysis& analysis = RunAnalysis(/*ssa_form=*/false);
absl::Status status = analysis.Verify();
EXPECT_TRUE(status.ok()) << status;

HloInstruction* done = module_->entry_computation()->root_instruction();
HloInstruction* start = done->mutable_operand(0);
HloInstruction* param0 = start->mutable_operand(0);
HloInstruction* param1 = start->mutable_operand(1);

EXPECT_TRUE(analysis.ValueIsDefinedAt(start, /*index=*/{}));
EXPECT_TRUE(analysis.ValueIsDefinedAt(start, /*index=*/{1}));
EXPECT_TRUE(analysis.ValueIsDefinedAt(start, /*index=*/{1, 0}));
EXPECT_TRUE(analysis.ValueIsDefinedAt(start, /*index=*/{1, 1}));

EXPECT_TRUE(analysis.ValueIsDefinedAt(done, /*index=*/{}));
EXPECT_FALSE(analysis.ValueIsDefinedAt(done, /*index=*/{0}));
EXPECT_FALSE(analysis.ValueIsDefinedAt(done, /*index=*/{1}));

EXPECT_THAT(
analysis.GetValueDefinedAt(param0).GetUses(),
UnorderedElementsAre(HloUse{start, 0, {}}, HloUse{done, 0, {0, 0}}));
EXPECT_THAT(
analysis.GetValueDefinedAt(param1).GetUses(),
UnorderedElementsAre(HloUse{start, 1, {}}, HloUse{done, 0, {0, 1}}));

EXPECT_THAT(HloValuesAt(start, /*index=*/{0, 0}),
UnorderedElementsAre(&analysis.GetValueDefinedAt(param0, {})));
EXPECT_THAT(HloValuesAt(start, /*index=*/{0, 1}),
UnorderedElementsAre(&analysis.GetValueDefinedAt(param1, {})));
EXPECT_THAT(HloValuesAt(done, /*index=*/{0}),
UnorderedElementsAre(&analysis.GetValueDefinedAt(start, {1, 0})));
EXPECT_THAT(HloValuesAt(done, /*index=*/{1}),
UnorderedElementsAre(&analysis.GetValueDefinedAt(start, {1, 1})));
}

TEST_F(HloDataflowAnalysisTest, AllGatherStartAndDoneWithTuple) {
const char* hlo_text = R"(
HloModule test
Expand Down
86 changes: 78 additions & 8 deletions xla/hlo/builder/xla_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -284,9 +284,19 @@ XlaOp XlaBuilderFriend::BuildCopyDone(XlaBuilder* builder, const XlaOp operand,
XlaOp XlaBuilderFriend::BuildCollectivePermuteStart(
XlaBuilder* builder, XlaOp operand,
const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
const std::optional<ChannelHandle>& channel_id) {
const std::optional<ChannelHandle>& channel_id, const bool inplace) {
return builder->CollectivePermuteImpl(operand, source_target_pairs,
channel_id, /*async=*/true);
channel_id, /*async=*/true, inplace);
}

XlaOp XlaBuilderFriend::BuildCollectivePermuteStart(
XlaBuilder* builder, absl::Span<const XlaOp> operands,
const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
const std::optional<ChannelHandle>& channel_id, const bool inplace) {
// TODO support multi-operand in-place collective permute
CHECK(!inplace);
return builder->CollectivePermuteImpl(operands, source_target_pairs,
channel_id, /*async=*/true, inplace);
}

XlaOp XlaBuilderFriend::BuildCollectivePermuteDone(XlaBuilder* builder,
Expand Down Expand Up @@ -4083,21 +4093,32 @@ XlaOp XlaBuilder::CollectiveBroadcastImpl(
XlaOp XlaBuilder::CollectivePermute(
XlaOp operand,
const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
const std::optional<ChannelHandle>& channel_id) {
const std::optional<ChannelHandle>& channel_id, const bool inplace) {
return CollectivePermuteImpl(operand, source_target_pairs, channel_id,
/*async=*/false);
/*async=*/false, inplace);
}

XlaOp XlaBuilder::CollectivePermute(
absl::Span<const XlaOp> operands,
const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
const std::optional<ChannelHandle>& channel_id, const bool inplace) {
// TODO support multi-operand in-place collective permute
CHECK(!inplace);
return CollectivePermuteImpl(operands, source_target_pairs, channel_id,
/*async=*/false, inplace);
}

XlaOp XlaBuilder::CollectivePermuteImpl(
XlaOp operand,
const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
const std::optional<ChannelHandle>& channel_id, bool async) {
const std::optional<ChannelHandle>& channel_id, bool async,
const bool inplace) {
return ReportErrorOrReturn([&]() -> absl::StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(
Shape shape,
ShapeInference::InferCollectivePermuteShape({operand_shape}));
ShapeInference::InferCollectivePermuteShape(operand_shape, inplace));
*instr.mutable_shape() = shape.ToProto();

for (const auto& pair : source_target_pairs) {
Expand All @@ -4116,6 +4137,45 @@ XlaOp XlaBuilder::CollectivePermuteImpl(
});
}

XlaOp XlaBuilder::CollectivePermuteImpl(
absl::Span<const XlaOp> operands,
const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
const std::optional<ChannelHandle>& channel_id, bool async,
const bool inplace) {
// TODO support multi-operand in-place collective permute
CHECK(!inplace);
return ReportErrorOrReturn([&]() -> absl::StatusOr<XlaOp> {
std::vector<const Shape*> operand_shapes;
for (const auto& operand : operands) {
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
operand_shapes.push_back(operand_shape);
}
CHECK_GT(operand_shapes.size(), 1);
HloInstructionProto instr;
auto tuple_operand_shapes =
ShapeUtil::MakeTupleShapeWithPtrs(operand_shapes);
TF_ASSIGN_OR_RETURN(Shape shape,
ShapeInference::InferCollectivePermuteShape(
&tuple_operand_shapes, inplace));
*instr.mutable_shape() =
ShapeUtil::MakeTupleShapeWithPtrs(operand_shapes).ToProto();

for (const auto& pair : source_target_pairs) {
auto* proto_pair = instr.add_source_target_pairs();
proto_pair->set_source(pair.first);
proto_pair->set_target(pair.second);
}
if (channel_id.has_value()) {
instr.set_channel_id(channel_id->handle());
}

return AddInstruction(std::move(instr),
async ? HloOpcode::kCollectivePermuteStart
: HloOpcode::kCollectivePermute,
operands);
});
}

XlaOp XlaBuilder::ReplicaId() {
return ReportErrorOrReturn([&]() -> absl::StatusOr<XlaOp> {
HloInstructionProto instr;
Expand Down Expand Up @@ -5630,9 +5690,19 @@ XlaOp CollectiveBroadcast(const XlaOp operand,
XlaOp CollectivePermute(
const XlaOp operand,
const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
const std::optional<ChannelHandle>& channel_id) {
const std::optional<ChannelHandle>& channel_id, const bool inplace) {
return operand.builder()->CollectivePermute(operand, source_target_pairs,
channel_id);
channel_id, inplace);
}

XlaOp MultiCollectivePermute(
absl::Span<const XlaOp> operands,
const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
const std::optional<ChannelHandle>& channel_id, const bool inplace) {
// TODO support multi-operand in-place collective permute
CHECK(!inplace);
return operands.at(0).builder()->CollectivePermute(
operands, source_target_pairs, channel_id, inplace);
}

XlaOp ReplicaId(XlaBuilder* builder) { return builder->ReplicaId(); }
Expand Down
40 changes: 35 additions & 5 deletions xla/hlo/builder/xla_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,13 @@ struct XlaBuilderFriend {
static XlaOp BuildCollectivePermuteStart(
XlaBuilder* builder, XlaOp operand,
const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
const std::optional<ChannelHandle>& channel_id = std::nullopt);
const std::optional<ChannelHandle>& channel_id = std::nullopt,
const bool inplace = false);
static XlaOp BuildCollectivePermuteStart(
XlaBuilder* builder, absl::Span<const XlaOp> operands,
const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
const std::optional<ChannelHandle>& channel_id = std::nullopt,
const bool inplace = false);
static XlaOp BuildCollectivePermuteDone(XlaBuilder* builder, XlaOp operands,
const Shape& shape);

Expand Down Expand Up @@ -880,7 +886,14 @@ class XlaBuilder {
XlaOp CollectivePermute(
XlaOp operand,
const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
const std::optional<ChannelHandle>& channel_id = std::nullopt);
const std::optional<ChannelHandle>& channel_id = std::nullopt,
const bool inplace = false);

XlaOp CollectivePermute(
absl::Span<const XlaOp> operands,
const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
const std::optional<ChannelHandle>& channel_id = std::nullopt,
const bool inplace = false);

XlaOp ReplicaId();

Expand Down Expand Up @@ -1562,7 +1575,11 @@ class XlaBuilder {
friend XlaOp CollectivePermute(
XlaOp operand,
const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
const std::optional<ChannelHandle>& channel_id);
const std::optional<ChannelHandle>& channel_id, const bool inplace);
friend XlaOp MultiCollectivePermute(
absl::Span<const XlaOp> operands,
const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
const std::optional<ChannelHandle>& channel_id, const bool inplace);
friend XlaOp ReplicaId(XlaBuilder* builder);
friend XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select,
absl::Span<const int64_t> window_dimensions,
Expand Down Expand Up @@ -1714,7 +1731,14 @@ class XlaBuilder {
XlaOp CollectivePermuteImpl(
XlaOp operand,
const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
const std::optional<ChannelHandle>& channel_id, bool async);
const std::optional<ChannelHandle>& channel_id, bool async,
const bool inplace);

XlaOp CollectivePermuteImpl(
absl::Span<const XlaOp> operands,
const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
const std::optional<ChannelHandle>& channel_id, bool async,
const bool inplace);

XlaOp ConditionalImpl(
XlaOp branch_index,
Expand Down Expand Up @@ -2647,7 +2671,13 @@ XlaOp CollectiveBroadcast(
XlaOp CollectivePermute(
XlaOp operand,
const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
const std::optional<ChannelHandle>& channel_id = std::nullopt);
const std::optional<ChannelHandle>& channel_id = std::nullopt,
const bool inplace = false);
XlaOp MultiCollectivePermute(
absl::Span<const XlaOp> operands,
const std::vector<std::pair<int64_t, int64_t>>& source_target_pairs,
const std::optional<ChannelHandle>& channel_id = std::nullopt,
const bool inplace = false);

// Enqueues an operation that returns the replica ID.
XlaOp ReplicaId(XlaBuilder* builder);
Expand Down
9 changes: 9 additions & 0 deletions xla/hlo/builder/xla_builder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,15 @@ TEST(XlaBuilderTest, CollectivePermute) {
EXPECT_EQ(GetRoot(*module)->opcode(), HloOpcode::kCollectivePermute);
}

TEST(XlaBuilderTest, CombinedCollectivePermute) {
XlaBuilder b(TestName());
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {5, 7}), "y");
MultiCollectivePermute({x, y}, {{0, 1}, {1, 2}, {2, 3}});
TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b));
EXPECT_EQ(GetRoot(*module)->opcode(), HloOpcode::kCollectivePermute);
}

TEST(XlaBuilderTest, GetDimensionSize) {
XlaBuilder b(TestName());
auto x =
Expand Down
Loading
Loading