Skip to content

Commit

Permalink
[XLA:SPMD] Add HLO annotation to disable collective matmul in SPMD.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 698271808
  • Loading branch information
seherellis authored and Google-ML-Automation committed Nov 20, 2024
1 parent c738435 commit 74da9de
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 2 deletions.
1 change: 1 addition & 0 deletions xla/service/spmd/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ cc_library(
"//xla:literal_util",
"//xla:protobuf_util",
"//xla:shape_util",
"//xla:side_effect_util",
"//xla:status_macros",
"//xla:types",
"//xla:util",
Expand Down
9 changes: 7 additions & 2 deletions xla/service/spmd/dot_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ limitations under the License.
#include "xla/service/spmd/spmd_partitioner_util.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/side_effect_util.h"
#include "xla/status_macros.h"
#include "xla/util.h"
#include "xla/window_util.h"
Expand Down Expand Up @@ -1905,8 +1906,12 @@ absl::StatusOr<HloInstruction*> PartitionBaseCase(
hlo.hlo()->opcode() == HloOpcode::kBitcast ||
hlo.hlo()->opcode() == HloOpcode::kTranspose;
};
bool should_skip_windowed_einsum = false;
if (options.disable_ag_rewrite_for_multiple_consumers) {
const auto& attrs = original_hlo->frontend_attributes().map();
bool should_skip_windowed_einsum =
attrs.contains(kXlaCollectiveMatmulAttr) &&
attrs.at(kXlaCollectiveMatmulAttr) == kXlaCollectiveMatmulNone;
if (!should_skip_windowed_einsum &&
options.disable_ag_rewrite_for_multiple_consumers) {
auto lhs_operand =
has_reshape_operand(lhs) ? lhs.hlo()->operand(0) : lhs.hlo();
auto rhs_operand =
Expand Down
24 changes: 24 additions & 0 deletions xla/service/spmd/spmd_partitioner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4975,6 +4975,30 @@ ENTRY entry {
op::Shape("f32[16,256,1024]")));
}

TEST_P(SpmdPartitioningTest, DisableWindowedEinsumWithUserAnnotation) {
absl::string_view hlo_string = R"(
HloModule module

ENTRY entry {
%p0 = f32[2048,2,3264]{2,1,0} parameter(0), sharding={devices=[1,1,2]0,1}
%p1 = f32[2,3264,2176]{2,1,0} parameter(1), sharding={devices=[2,1,1]0,1}
ROOT %dot.224 = f32[2048,2176]{1,0} dot(f32[2048,2,3264]{2,1,0} %p0, f32[2,3264,2176]{2,1,0} %p1), lhs_contracting_dims={1,2}, rhs_contracting_dims={0,1}, sharding={devices=[1,2]0,1}, frontend_attributes={_xla_collective_matmul="none"}
})";

TF_ASSERT_OK_AND_ASSIGN(
auto module,
PartitionComputation(hlo_string, /*num_devices=*/2,
/*conv_halo_exchange_always_on_lhs=*/true,
/*choose_faster_windowed_einsum=*/false,
/*unroll_windowed_einsum=*/false,
/*bidirectional_windowed_einsum=*/false,
/*threshold_for_windowed_einsum_mib=*/0));
ASSERT_FALSE(absl::c_any_of(module->entry_computation()->instructions(),
[](const HloInstruction* inst) {
return inst->opcode() == HloOpcode::kWhile;
}));
}

TEST_P(SpmdPartitioningTest, EinsumBatchPartitioned) {
absl::string_view hlo_string = R"(
HloModule module
Expand Down
2 changes: 2 additions & 0 deletions xla/side_effect_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ const char kXlaCollectiveMatmulRhsAg[] = "rhs_ag";

const char kXlaCollectiveMatmulRs[] = "rs";

const char kXlaCollectiveMatmulNone[] = "none";

const char kXlaMultiRecvCountAttr[] = "_xla_multi_recv_count";

} // namespace xla
1 change: 1 addition & 0 deletions xla/side_effect_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ extern const char kXlaCollectiveMatmulAttr[];
extern const char kXlaCollectiveMatmulLhsAg[];
extern const char kXlaCollectiveMatmulRhsAg[];
extern const char kXlaCollectiveMatmulRs[];
extern const char kXlaCollectiveMatmulNone[];

// XLA frontend attribute for specifying the number of sends this recv should
// match.
Expand Down

0 comments on commit 74da9de

Please sign in to comment.