Skip to content

Commit

Permalink
Update all_to_all op interpreter to fix issue#2433 (#2457)
Browse files Browse the repository at this point in the history
fixes: #2433


```
scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group] where receiver_index = process_group.index(receiver).
```
Each `sender`/process from the `process_group` contributes a `split_part` to
the `scattered_parts`. This PR ensures that the order of the `split_part`s in
`scattered_parts` matches the order of the senders/processes
within the `process_group`.
  • Loading branch information
abhigunj authored Jul 24, 2024
1 parent eba821a commit b056d3d
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 9 deletions.
19 changes: 10 additions & 9 deletions stablehlo/reference/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1198,20 +1198,21 @@ Tensor allToAllOp(const Tensor &operand, Axis splitDimension,
if (channelId > 0) processGroups = process->crossPartition(replicaGroups);

auto processGroup = processGroups.findGroup(process->getId());
if (!processGroup)
if (!processGroup.has_value())
llvm::report_fatal_error(invalidArgument(
"Failed to find process group with process_id: (%d, %d)",
process->getId().replicaId, process->getId().partitionId));
auto groupOperands = process->rendezvous(*processGroup, channelId, {operand})
.getSortedTensors();
auto groupOperands =
process->rendezvous(processGroup.value(), channelId, {operand});

auto receiverIndex = llvm::find(processGroup.value(), process->getId()) -
processGroup->begin();

SmallVector<Tensor> scatteredParts;
for (const auto &groupOperand : groupOperands) {
auto splitParts = split(groupOperand.front(), splitCount, splitDimension,
operand.getType().getContext());
for (auto [i, processId] : llvm::enumerate(*processGroup))
if (processId == process->getId())
scatteredParts.push_back(splitParts[i]);
for (const auto &sender : processGroup.value()) {
auto splitParts = split(groupOperands.lookup(sender).front(), splitCount,
splitDimension, operand.getType().getContext());
scatteredParts.push_back(splitParts[receiverIndex]);
}
return concatenateOp(scatteredParts, concatDimension, resultType);
}
Expand Down
32 changes: 32 additions & 0 deletions stablehlo/tests/interpret/all_to_all.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,38 @@ module @cross_replica {

// -----

module @cross_replica_issue_2433 {
func.func @all_to_all(%operand : tensor<2x4xi64>) -> tensor<4x2xi64> {
%result = "stablehlo.all_to_all"(%operand) {
split_dimension = 1 : i64,
concat_dimension = 0 : i64,
split_count = 2 : i64,
replica_groups = dense<[[1, 0]]> : tensor<1x2xi64>
} : (tensor<2x4xi64>) -> tensor<4x2xi64>
return %result : tensor<4x2xi64>
}
func.func @main() {
%inputs0 = stablehlo.constant dense<[[1, 2, 3, 4],
[5, 6, 7, 8]]> : tensor<2x4xi64>
%inputs1 = stablehlo.constant dense<[[9, 10, 11, 12],
[13, 14, 15, 16]]> : tensor<2x4xi64>
%results:2 = "interpreter.run_parallel"(%inputs0, %inputs1) {
programs=[[@all_to_all], [@all_to_all]]
} : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>)
check.expect_eq_const %results#0, dense<[[11, 12],
[15, 16],
[3, 4],
[7, 8]]> : tensor<4x2xi64>
check.expect_eq_const %results#1, dense<[[9, 10],
[13, 14],
[1, 2],
[5, 6]]> : tensor<4x2xi64>
func.return
}
}

// -----

module @cross_partition {
func.func @all_to_all(%operand : tensor<2x4xi64>) -> tensor<4x2xi64> {
%result = "stablehlo.all_to_all"(%operand) {
Expand Down

0 comments on commit b056d3d

Please sign in to comment.