Skip to content

Commit

Permalink
Cleanup. Refactor gather_scatter_handler. Remove unused code. Replace…
Browse files Browse the repository at this point in the history
… sort with stable_sort.

PiperOrigin-RevId: 697870443
  • Loading branch information
ZixuanJiang authored and Google-ML-Automation committed Nov 19, 2024
1 parent d4c65bf commit a96fed6
Showing 1 changed file with 10 additions and 41 deletions.
51 changes: 10 additions & 41 deletions xla/service/spmd/gather_scatter_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -847,16 +847,11 @@ std::vector<decltype(PartitionGather)*> GatherPartitionMethodsOrderedByCost(
partition_method, std::make_pair(memory_cost, communication_cost));
ordered_partition_methods.push_back(partition_method);
}
absl::c_sort(ordered_partition_methods, [&](decltype(PartitionGather)* lhs,
decltype(PartitionGather)* rhs) {
auto [lhs_memory_cost, lhs_communication_cost] =
partition_method_costs[lhs];
auto [rhs_memory_cost, rhs_communication_cost] =
partition_method_costs[rhs];
return lhs_memory_cost != rhs_memory_cost
? lhs_memory_cost < rhs_memory_cost
: lhs_communication_cost < rhs_communication_cost;
});
absl::c_stable_sort(
ordered_partition_methods,
[&](decltype(PartitionGather)* lhs, decltype(PartitionGather)* rhs) {
return partition_method_costs[lhs] < partition_method_costs[rhs];
});
VLOG(5) << "Gather partitioning methods(ordered by cost):";
for (auto partition_method : ordered_partition_methods) {
VLOG(5) << " "
Expand Down Expand Up @@ -940,27 +935,6 @@ int64_t ShapeSizeInBytesSum(absl::Span<const T> operands, F&& get_shape) {
});
}

int64_t BaseShapeSizeSum(absl::Span<const PartitionedHlo> phlos) {
return ShapeSizeInBytesSum(
phlos, [](const PartitionedHlo& phlo) { return phlo.base_shape(); });
}

int64_t BaseShapeSizeSum(absl::Span<const PartitionedHlo> phlos,
const HloSharding& sharding) {
return ShapeSizeInBytesSum(phlos, [&sharding](const PartitionedHlo& phlo) {
return MakePartitionedShape(phlo.base_shape(), sharding);
});
}

int64_t ShapeSizeSum(absl::Span<const PartitionedHlo> phlos) {
return ShapeSizeInBytesSum(
phlos, [](const PartitionedHlo& phlo) { return phlo.hlo()->shape(); });
}

int64_t ShapeSizeSum(absl::Span<const Shape> shapes) {
return ShapeSizeInBytesSum(shapes, [](const Shape& shape) { return shape; });
}

Shape MaybeMakeTupleShape(absl::Span<const HloInstruction* const> hlos) {
if (hlos.size() == 1) {
return hlos[0]->shape();
Expand Down Expand Up @@ -1702,16 +1676,11 @@ std::vector<decltype(PartitionScatter)*> ScatterPartitionMethodsOrderedByCost(
partition_method, std::make_pair(memory_cost, communication_cost));
ordered_partition_methods.push_back(partition_method);
}
absl::c_sort(ordered_partition_methods, [&](decltype(PartitionScatter)* lhs,
decltype(PartitionScatter)* rhs) {
auto [lhs_memory_cost, lhs_communication_cost] =
partition_method_costs[lhs];
auto [rhs_memory_cost, rhs_communication_cost] =
partition_method_costs[rhs];
return lhs_memory_cost != rhs_memory_cost
? lhs_memory_cost < rhs_memory_cost
: lhs_communication_cost < rhs_communication_cost;
});
absl::c_stable_sort(
ordered_partition_methods,
[&](decltype(PartitionScatter)* lhs, decltype(PartitionScatter)* rhs) {
return partition_method_costs[lhs] < partition_method_costs[rhs];
});
VLOG(5) << "Scatter partitioning methods(ordered by cost):";
for (auto partition_method : ordered_partition_methods) {
VLOG(5) << " "
Expand Down

0 comments on commit a96fed6

Please sign in to comment.