From b056d3da25c18169c14018a1e953112852a53453 Mon Sep 17 00:00:00 2001
From: Abhinav Gunjal <agunjal@google.com>
Date: Wed, 24 Jul 2024 09:46:07 -0700
Subject: [PATCH] Update all_to_all op interpreter to fix issue#2433 (#2457)

fixes: https://github.com/openxla/stablehlo/issues/2433

```
scattered_parts...@receiver = [split_parts...@sender[receiver_index] for sender in process_group]
where receiver_index = process_group.index(receiver).
```

Each `sender/process` from `process_group` contributes a `split_part` to
the `scattered_parts`. This PR ensures the order of `split_part` in
`scattered_parts` matches the order of the `sender/processes` within the
`process_group`.
---
 stablehlo/reference/Ops.cpp               | 19 +++++++-------
 stablehlo/tests/interpret/all_to_all.mlir | 32 +++++++++++++++++++++++
 2 files changed, 42 insertions(+), 9 deletions(-)

diff --git a/stablehlo/reference/Ops.cpp b/stablehlo/reference/Ops.cpp
index 143b38df92f..1e6769ebee8 100644
--- a/stablehlo/reference/Ops.cpp
+++ b/stablehlo/reference/Ops.cpp
@@ -1198,20 +1198,21 @@ Tensor allToAllOp(const Tensor &operand, Axis splitDimension,
   if (channelId > 0) processGroups = process->crossPartition(replicaGroups);
 
   auto processGroup = processGroups.findGroup(process->getId());
-  if (!processGroup)
+  if (!processGroup.has_value())
     llvm::report_fatal_error(invalidArgument(
         "Failed to find process group with process_id: (%d, %d)",
         process->getId().replicaId, process->getId().partitionId));
 
-  auto groupOperands = process->rendezvous(*processGroup, channelId, {operand})
-                           .getSortedTensors();
+  auto groupOperands =
+      process->rendezvous(processGroup.value(), channelId, {operand});
+
+  auto receiverIndex = llvm::find(processGroup.value(), process->getId()) -
+                       processGroup->begin();
   SmallVector<Tensor> scatteredParts;
-  for (const auto &groupOperand : groupOperands) {
-    auto splitParts = split(groupOperand.front(), splitCount, splitDimension,
-                            operand.getType().getContext());
-    for (auto [i, processId] : llvm::enumerate(*processGroup))
-      if (processId == process->getId())
-        scatteredParts.push_back(splitParts[i]);
+  for (const auto &sender : processGroup.value()) {
+    auto splitParts = split(groupOperands.lookup(sender).front(), splitCount,
+                            splitDimension, operand.getType().getContext());
+    scatteredParts.push_back(splitParts[receiverIndex]);
   }
   return concatenateOp(scatteredParts, concatDimension, resultType);
 }
diff --git a/stablehlo/tests/interpret/all_to_all.mlir b/stablehlo/tests/interpret/all_to_all.mlir
index 223030acdf4..08da78d09e5 100644
--- a/stablehlo/tests/interpret/all_to_all.mlir
+++ b/stablehlo/tests/interpret/all_to_all.mlir
@@ -32,6 +32,38 @@ module @cross_replica {
 
 // -----
 
+module @cross_replica_issue_2433 {
+  func.func @all_to_all(%operand : tensor<2x4xi64>) -> tensor<4x2xi64> {
+    %result = "stablehlo.all_to_all"(%operand) {
+      split_dimension = 1 : i64,
+      concat_dimension = 0 : i64,
+      split_count = 2 : i64,
+      replica_groups = dense<[[1, 0]]> : tensor<1x2xi64>
+    } : (tensor<2x4xi64>) -> tensor<4x2xi64>
+    return %result : tensor<4x2xi64>
+  }
+  func.func @main() {
+    %inputs0 = stablehlo.constant dense<[[1, 2, 3, 4],
+                                         [5, 6, 7, 8]]> : tensor<2x4xi64>
+    %inputs1 = stablehlo.constant dense<[[9, 10, 11, 12],
+                                         [13, 14, 15, 16]]> : tensor<2x4xi64>
+    %results:2 = "interpreter.run_parallel"(%inputs0, %inputs1) {
+      programs=[[@all_to_all], [@all_to_all]]
+    } : (tensor<2x4xi64>, tensor<2x4xi64>) -> (tensor<4x2xi64>, tensor<4x2xi64>)
+    check.expect_eq_const %results#0, dense<[[11, 12],
+                                             [15, 16],
+                                             [3, 4],
+                                             [7, 8]]> : tensor<4x2xi64>
+    check.expect_eq_const %results#1, dense<[[9, 10],
+                                             [13, 14],
+                                             [1, 2],
+                                             [5, 6]]> : tensor<4x2xi64>
+    func.return
+  }
+}
+
+// -----
+
 module @cross_partition {
   func.func @all_to_all(%operand : tensor<2x4xi64>) -> tensor<4x2xi64> {
     %result = "stablehlo.all_to_all"(%operand) {