Skip to content

Commit

Permalink
Cleanup. Refactor GetGatherScatterBatchParallelDims. No behavior change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 698230884
  • Loading branch information
ZixuanJiang authored and Google-ML-Automation committed Nov 20, 2024
1 parent 229e8fb commit c738435
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions xla/hlo/utils/hlo_sharding_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2336,12 +2336,12 @@ std::optional<GatherScatterParallelDims> GetGatherScatterBatchParallelDims(
return orig_indices;
};
indices = findConcatenate(indices);

// Handle cases where we concatenate pieces of the indices one at a time.
if (indices->opcode() == HloOpcode::kConcatenate &&
indices->concatenate_dimension() == index_vector_dim) {
int concatenated_dims = 0;
for (int i = 0; i < indices->operand_count(); ++i) {
const HloInstruction* op = indices->operand(i);
for (const HloInstruction* op : indices->operands()) {
const int64_t num_indices_from_element =
op->shape().dimensions_size() > index_vector_dim
? op->shape().dimensions(index_vector_dim)
Expand All @@ -2367,32 +2367,33 @@ std::optional<GatherScatterParallelDims> GetGatherScatterBatchParallelDims(
index_parallel_in_dim.assign(num_indices_from_element, *maybe_iota_dim);
}
}

absl::InlinedVector<int64_t, 1> indices_parallel_dims;
absl::InlinedVector<int64_t, 1> operand_parallel_dims;
// Map the parallelizable dimension from the iota to the dimensions of the
// output and the operand. These dimensions are interconnected, but between
// operands and index they could have different spots in the shape because the
// position of the index dimension in the operand is determined by index_map.
for (int i = 0; i < index_parallel_in_dim.size(); ++i) {
int index_parallel_dim = index_parallel_in_dim[i];
if (index_parallel_dim == -1) {
for (int64_t i = 0; i < index_parallel_in_dim.size(); ++i) {
int64_t indices_parallel_dim = index_parallel_in_dim[i];
int64_t operand_parallel_dim = index_map[i];
if (indices_parallel_dim == -1) {
continue;
}
if (absl::c_linear_search(indices_parallel_dims, index_parallel_dim)) {
if (absl::c_linear_search(indices_parallel_dims, indices_parallel_dim)) {
return std::nullopt;
}
// Considered parallel only if the slice is of size 1 over the operand.
if (slice_sizes[index_map[i]] == 1) {
indices_parallel_dims.push_back(index_parallel_dim);
operand_parallel_dims.push_back(index_map[i]);
if (operand->shape().dimensions(operand_parallel_dims.back()) !=
indices->shape().dimensions(indices_parallel_dims.back())) {
if (slice_sizes[operand_parallel_dim] == 1) {
if (operand->shape().dimensions(operand_parallel_dim) !=
indices->shape().dimensions(indices_parallel_dim)) {
return std::nullopt;
}
} else {
index_parallel_in_dim[i] = -1;
indices_parallel_dims.push_back(indices_parallel_dim);
operand_parallel_dims.push_back(operand_parallel_dim);
}
}

if (!indices_parallel_dims.empty()) {
return GatherScatterParallelDims{indices_parallel_dims,
operand_parallel_dims};
Expand Down

0 comments on commit c738435

Please sign in to comment.