[XLA:CPU] Add a direct implementation of ReduceScatter, instead of lowering ReduceScatter to AllReduce+DynamicSlice.

PiperOrigin-RevId: 586424242
hawkinsp authored and copybara-github committed Nov 29, 2023
1 parent 954527c commit 69f26cf
Showing 10 changed files with 381 additions and 107 deletions.
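
Background for the change (an illustration, not part of the commit): a reduce-scatter is an all-reduce whose result is partitioned across the participants, with rank i keeping only chunk i. The previous lowering computed the full all-reduce result on every rank and then extracted each rank's chunk with a dynamic-slice; a direct implementation lets each rank compute and store only its own chunk. A minimal single-process sketch of the semantics, assuming float inputs and a SUM reduction:

    #include <cstddef>
    #include <vector>

    // Reference semantics only: inputs[r] is rank r's contribution, with
    // num_ranks * chunk elements. The result gives rank r exactly chunk
    // reduced elements (chunk r of the equivalent all-reduce output).
    std::vector<std::vector<float>> ReduceScatterSum(
        const std::vector<std::vector<float>>& inputs, size_t chunk) {
      const size_t num_ranks = inputs.size();
      std::vector<std::vector<float>> outputs(num_ranks,
                                              std::vector<float>(chunk, 0.0f));
      for (size_t r = 0; r < num_ranks; ++r) {    // receiving rank
        for (size_t e = 0; e < chunk; ++e) {      // element within rank r's chunk
          for (size_t src = 0; src < num_ranks; ++src) {
            outputs[r][e] += inputs[src][r * chunk + e];  // reduce across ranks
          }
        }
      }
      return outputs;
    }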
xla/service/cpu/BUILD: 3 changes (2 additions, 1 deletion)
@@ -300,7 +300,6 @@ cc_library(
         "//xla/service:optimization_barrier_expander",
         "//xla/service:qr_expander",
         "//xla/service:reduce_decomposer",
-        "//xla/service:reduce_scatter_decomposer",
         "//xla/service:reshape_decomposer",
         "//xla/service:reshape_mover",
         "//xla/service:result_caster",
@@ -877,6 +876,7 @@ cc_library(
         "//xla:shape_util",
         "//xla:statusor",
         "//xla:types",
+        "//xla:util",
         "//xla:xla_data_proto_cc",
         "//xla/service:collective_ops_utils",
         "//xla/service:computation_placer",
@@ -886,6 +886,7 @@ cc_library(
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/status",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/time",
xla/service/cpu/collectives_interface.h: 6 changes (6 additions, 0 deletions)
@@ -67,6 +67,12 @@ class CollectivesCommunicator {
   virtual absl::Status AllGather(const RendezvousKey& key, size_t chunk_bytes,
                                  const void* input_buffer, void* output_buffer,
                                  absl::Duration timeout) = 0;
+
+  // Performs a reduce-scatter
+  virtual absl::Status ReduceScatter(
+      const RendezvousKey& key, ReductionKind reduction_kind,
+      PrimitiveType element_type, size_t chunk_elems, const void* input_buffer,
+      void* output_buffer, absl::Duration timeout) = 0;
 };

 class CollectivesInterface {
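
To make the new contract concrete, here is a hypothetical caller sketch; the communicator `comm`, the rendezvous `key`, and the sizes are assumptions for illustration, not code from this commit. Each participant supplies a full input of num_ranks * chunk_elems elements and receives its own chunk_elems-element reduced chunk:

    // Sketch: assumes `comm` is a CollectivesCommunicator for num_ranks
    // participants and `key` is the RendezvousKey set up by the runtime.
    std::vector<float> input(num_ranks * chunk_elems);  // this rank's contribution
    std::vector<float> output(chunk_elems);             // this rank's reduced chunk
    absl::Status status = comm->ReduceScatter(
        key, ReductionKind::SUM, PrimitiveType::F32, chunk_elems,
        input.data(), output.data(), absl::InfiniteDuration());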
xla/service/cpu/cpu_compiler.cc: 2 changes (0 additions, 2 deletions)
@@ -191,7 +191,6 @@ limitations under the License.
 #include "xla/service/optimization_barrier_expander.h"
 #include "xla/service/qr_expander.h"
 #include "xla/service/reduce_decomposer.h"
-#include "xla/service/reduce_scatter_decomposer.h"
 #include "xla/service/reshape_decomposer.h"
 #include "xla/service/reshape_mover.h"
 #include "xla/service/result_caster.h"
@@ -685,7 +684,6 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
   pipeline.AddPass<EighExpander>();
   pipeline.AddPass<TriangularSolveExpander>();
   pipeline.AddPass<AllToAllDecomposer>();
-  pipeline.AddPass<ReduceScatterDecomposer>();
   pipeline.AddPass<StochasticConvertDecomposer>();

   // Inline computations with a single call site.
xla/service/cpu/cpu_layout_assignment.cc: 7 changes (7 additions, 0 deletions)
@@ -129,6 +129,13 @@ Status CpuLayoutAssignment::AddBackendConstraints(
       const HloInstruction* op = instruction->operand(*op_idx);
       TF_RETURN_IF_ERROR(
           SetOperandLayout(ColMajorShape(op->shape()), instruction, *op_idx));
+    } else if (instruction->opcode() == HloOpcode::kReduceScatter) {
+      // XLA:CPU can only support reduce-scatter where the scatter dimension
+      // is the most major dimension in the layout.
+      auto ars = Cast<HloReduceScatterInstruction>(instruction);
+      TF_RETURN_IF_ERROR(SetInstructionLayout(
+          ShapeUtil::MoveDimToMajor(ars->shape(), ars->scatter_dimension()),
+          ars));
     } else if (instruction->opcode() == HloOpcode::kAllGather) {
       // XLA:CPU can only support all-gathers where the gather dimension is the
       // most major dimension in the layout.
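
As an illustration of the constraint above (an assumed example, not part of the commit): for an f32[4,16] reduce-scatter result with scatter dimension 1, ShapeUtil::MoveDimToMajor rewrites the layout so that dimension 1 is most major, which makes each scattered chunk contiguous in memory:

    // Hypothetical shapes; only the layout change matters here.
    Shape shape = ShapeUtil::MakeShape(F32, {4, 16});        // default layout {1,0}
    Shape constrained = ShapeUtil::MoveDimToMajor(shape, /*dim=*/1);
    // `constrained` now has minor_to_major {0,1}: dimension 1 is most major,
    // so slicing along the scatter dimension yields contiguous blocks.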
xla/service/cpu/cpu_runtime.cc: 77 changes (65 additions, 12 deletions)
@@ -29,6 +29,9 @@ limitations under the License.
 #include "absl/algorithm/container.h"
 #include "absl/base/attributes.h"
 #include "absl/container/flat_hash_map.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
 #include "absl/strings/str_split.h"
 #include "absl/synchronization/mutex.h"
 #include "absl/time/time.h"
@@ -46,6 +49,7 @@ limitations under the License.
 #include "xla/statusor.h"
 #include "xla/stream_executor/device_memory.h"
 #include "xla/stream_executor/stream_executor.h"
+#include "xla/util.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/logging.h"
 #include "tsl/platform/status.h"
@@ -143,6 +147,8 @@ extern const char* const kTracingEndSymbolName = "__xla_cpu_runtime_TracingEnd";
 extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_";
 extern const char* const kAllReduceSymbolName = "__xla_cpu_runtime_AllReduce";
 extern const char* const kAllGatherSymbolName = "__xla_cpu_runtime_AllGather";
+extern const char* const kReduceScatterSymbolName =
+    "__xla_cpu_runtime_ReduceScatter";
 extern const char* const kAllToAllSymbolName = "__xla_cpu_runtime_AllToAll";
 extern const char* const kCollectivePermuteSymbolName =
     "__xla_cpu_runtime_CollectivePermute";
@@ -315,6 +321,19 @@ CollectivesInterface* GetInProcessCollectivesImpl() {

 absl::Duration DefaultCollectiveTimeout() { return absl::InfiniteDuration(); }

+absl::StatusOr<int> RankInGlobalDevices(
+    absl::Span<GlobalDeviceId const> devices, GlobalDeviceId device) {
+  auto it = absl::c_find(devices, device);
+  if (it == devices.end()) {
+    return InvalidArgument(
+        "Device %d not present in global devices %s.", device.value(),
+        absl::StrJoin(devices, ", ", [](std::string* out, GlobalDeviceId id) {
+          absl::StrAppend(out, id.value());
+        }));
+  }
+  return std::distance(devices.begin(), it);
+}
+
 ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY
 void AllToAllImpl(const ExecutableRunOptions* run_options,
                   int32_t channel_id_present, int64_t op_id,
@@ -331,9 +350,7 @@ void AllToAllImpl(const ExecutableRunOptions* run_options,
       GetRendezvousKey(run_options, device, group, channel_id_present,
                        /*use_global_device_ids=*/std::nullopt, op_id);

-  auto it = absl::c_find(rendezvous_key.global_devices, device);
-  CHECK(it != rendezvous_key.global_devices.end());
-  int rank = std::distance(rendezvous_key.global_devices.begin(), it);
+  int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value();

   CollectivesInterface* collectives = GetInProcessCollectivesImpl();

@@ -361,9 +378,7 @@ void AllGatherImpl(const ExecutableRunOptions* run_options,
       GetRendezvousKey(run_options, device, group, channel_id_present,
                        /*use_global_device_ids=*/std::nullopt, op_id);

-  auto it = absl::c_find(rendezvous_key.global_devices, device);
-  CHECK(it != rendezvous_key.global_devices.end());
-  int rank = std::distance(rendezvous_key.global_devices.begin(), it);
+  int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value();

   CollectivesInterface* collectives = GetInProcessCollectivesImpl();

@@ -374,6 +389,35 @@ void AllGatherImpl(const ExecutableRunOptions* run_options,
                                       DefaultCollectiveTimeout()));
 }

+ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY
+void ReduceScatterImpl(const ExecutableRunOptions* run_options,
+                       const void* replica_groups_str,
+                       int32_t replica_groups_str_size,
+                       int32_t channel_id_present, int64_t op_id,
+                       int32_t reduction_kind, int32_t element_type,
+                       int64_t chunk_elems, void* input_buffer,
+                       void* output_buffer) {
+  GlobalDeviceId device(GetDeviceOrdinal(run_options));
+  std::string_view replica_groups_serialized(
+      static_cast<const char*>(replica_groups_str), replica_groups_str_size);
+  std::vector<ReplicaGroup> group =
+      ParseReplicaGroupsOnly(replica_groups_serialized).value();
+  RendezvousKey rendezvous_key =
+      GetRendezvousKey(run_options, device, group, channel_id_present,
+                       /*use_global_device_ids=*/std::nullopt, op_id);
+
+  int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value();
+
+  CollectivesInterface* collectives = GetInProcessCollectivesImpl();
+
+  auto communicator =
+      collectives->GetCommunicator(rendezvous_key.global_devices, rank).value();
+  TF_CHECK_OK(communicator->ReduceScatter(
+      rendezvous_key, static_cast<ReductionKind>(reduction_kind),
+      static_cast<PrimitiveType>(element_type), chunk_elems, input_buffer,
+      output_buffer, DefaultCollectiveTimeout()));
+}
+
 ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY
 void AllReduceImpl(const ExecutableRunOptions* run_options,
                    const void* replica_groups_str,
@@ -399,9 +443,7 @@ void AllReduceImpl(const ExecutableRunOptions* run_options,
   CHECK((num_buffers > 1 && shape.IsTuple()) ||
         (num_buffers == 1 && LayoutUtil::IsDenseArray(shape)));

-  auto it = absl::c_find(rendezvous_key.global_devices, device);
-  CHECK(it != rendezvous_key.global_devices.end());
-  int rank = std::distance(rendezvous_key.global_devices.begin(), it);
+  int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value();

   CollectivesInterface* collectives = GetInProcessCollectivesImpl();

@@ -450,9 +492,7 @@ void CollectivePermuteImpl(const ExecutableRunOptions* run_options,
       GetRendezvousKey(run_options, device, {}, channel_id_present,
                        /*use_global_device_ids=*/std::nullopt, op_id);

-  auto it = absl::c_find(rendezvous_key.global_devices, device);
-  CHECK(it != rendezvous_key.global_devices.end());
-  int rank = std::distance(rendezvous_key.global_devices.begin(), it);
+  int rank = RankInGlobalDevices(rendezvous_key.global_devices, device).value();

   CollectivesInterface* collectives = GetInProcessCollectivesImpl();

@@ -542,6 +582,19 @@ void __xla_cpu_runtime_AllGather(const xla::ExecutableRunOptions* run_options,
       run_options, channel_id_present, op_id, replica_groups_str,
       replica_groups_str_size, buffer_size, source_buffer, destination_buffer);
 }
+
+void __xla_cpu_runtime_ReduceScatter(
+    const xla::ExecutableRunOptions* run_options,
+    const void* replica_groups_str, int32_t replica_groups_str_size,
+    int32_t channel_id_present, int64_t op_id, int32_t reduction_kind,
+    int32_t element_type, int64_t chunk_elems, void* input_buffer,
+    void* output_buffer) {
+  return xla::cpu::runtime::ReduceScatterImpl(
+      run_options, replica_groups_str, replica_groups_str_size,
+      channel_id_present, op_id, reduction_kind, element_type, chunk_elems,
+      input_buffer, output_buffer);
+}
+
 void __xla_cpu_runtime_AllReduce(const xla::ExecutableRunOptions* run_options,
                                  const void* replica_groups_str,
                                  int32_t replica_groups_str_size,
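
For orientation, a sketch of how compiled code might invoke the new entry point; the `SerializeReplicaGroups` helper, the buffer variables, and all argument values are assumptions for illustration, since the real call is emitted by the IR emitter:

    // Hypothetical call site; SerializeReplicaGroups stands in for however
    // the emitter produces the string that ParseReplicaGroupsOnly expects.
    std::string groups = SerializeReplicaGroups();  // assumed helper
    __xla_cpu_runtime_ReduceScatter(
        run_options, groups.data(), static_cast<int32_t>(groups.size()),
        /*channel_id_present=*/0, /*op_id=*/0,
        static_cast<int32_t>(xla::ReductionKind::SUM),
        static_cast<int32_t>(xla::F32), /*chunk_elems=*/1024,
        input_buffer, output_buffer);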
xla/service/cpu/cpu_runtime.h: 8 changes (8 additions, 0 deletions)
@@ -85,6 +85,7 @@ extern const char* const kTracingStartSymbolName;
 extern const char* const kTracingEndSymbolName;
 extern const char* const kAllToAllSymbolName;
 extern const char* const kAllGatherSymbolName;
+extern const char* const kReduceScatterSymbolName;
 extern const char* const kOneDnnMatMulSymbolName;

 // All symbol names for XLA CPU runtime functions need to start with this
@@ -202,6 +203,13 @@ extern void __xla_cpu_runtime_AllGather(
     int32_t replica_groups_str_size, int64_t buffer_size, void* source_buffer,
     void* destination_buffer);

+void __xla_cpu_runtime_ReduceScatter(
+    const xla::ExecutableRunOptions* run_options,
+    const void* replica_groups_str, int32_t replica_groups_str_size,
+    int32_t channel_id_present, int64_t op_id, int32_t reduction_kind,
+    int32_t element_type, int64_t chunk_elems, void* input_buffer,
+    void* output_buffer);
+
 // Write the partition ID into the output buffer.
 extern void __xla_cpu_runtime_PartitionId(
     const xla::ExecutableRunOptions* run_options, void* output_buffer);