[RFC] Standardize Collective (AllGatherOp, AllReduceOp, AllToAllOp) ops to support variadic operand/result (#2099)

This RFC proposes standardizing the collective (AllGatherOp, AllReduceOp,
AllToAllOp) ops to enable support for `multi-operand` and `multi-result` semantics.

Please review and provide your feedback.
abhigunj authored Jun 26, 2024
1 parent 2066701 commit 6b69e21
Showing 2 changed files with 106 additions and 65 deletions.
137 changes: 72 additions & 65 deletions docs/spec.md
@@ -682,8 +682,8 @@ it only exists to establish data dependencies from `result` to `inputs`.
#### Semantics

Within each process group in the StableHLO process grid, concatenates the values
of the `operands` tensors from each process along `all_gather_dim` and produces
`results` tensors.

The operation splits the StableHLO process grid into `process_groups` which is
defined as follows:
@@ -697,57 +697,61 @@ defined as follows:

Afterwards, within each `process_group`:

* `operands...@receiver = [operands...@sender for sender in process_group]` for all
  `receiver` in `process_group`.
* `results...@process = concatenate(operands...@process, all_gather_dim)` for all
  `process` in `process_group`.
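
To make the per-group semantics concrete, here is a minimal, non-normative Python/NumPy sketch of variadic all-gather within a single process group. The function name, the `operands_per_process[p][i]` data layout, and the sample values are illustrative assumptions, not spec text:

```python
import numpy as np

def all_gather_group(operands_per_process, all_gather_dim):
    """Simulate variadic all_gather within one process group.

    operands_per_process[p][i] is the i-th operand tensor on process p.
    Every process in the group receives the same list of results.
    """
    num_operands = len(operands_per_process[0])
    # For each operand slot i, gather that operand from every sender and
    # concatenate along all_gather_dim.
    gathered = [
        np.concatenate([sender[i] for sender in operands_per_process],
                       axis=all_gather_dim)
        for i in range(num_operands)
    ]
    return [gathered for _ in operands_per_process]

# Two processes, one operand each, gathered along dimension 0.
ops = [[np.array([[1], [2]])], [np.array([[3], [4]])]]
out = all_gather_group(ops, all_gather_dim=0)
# out[p][0] == [[1], [2], [3], [4]] for every process p in the group.
```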

#### Inputs

| Label | Name | Type | Constraints |
|-------|-------------------------|-------------------------------------------------------------|-------------|
| (I1) | `operands` | variadic number of tensors or per-tensor quantized tensors | (C1), (C6) |
| (I2) | `all_gather_dim` | constant of type `si64` | (C1), (C6) |
| (I3) | `replica_groups` | 2-dimensional tensor constant of type `si64` | (C2-C4) |
| (I4) | `channel_id` | constant of type `si64` | (C5) |
| (I5) | `use_global_device_ids` | constant of type `i1` | (C5) |

#### Outputs

| Name | Type | Constraints |
|-----------|------------------------------------------------------------|-------------|
| `results` | variadic number of tensors or per-tensor quantized tensors | (C6) |

#### Constraints

* (C1) `0 <= all_gather_dim < rank(operands...)`.
* (C2) `is_unique(replica_groups)`.
* (C3) `size(replica_groups)` is defined as:
* `num_replicas` if `cross_replica` is used.
* `num_replicas` if `cross_replica_and_partition` is used.
* `num_processes` if `flattened_ids` is used.
* (C4) `0 <= replica_groups < size(replica_groups)`.
* (C5) If `use_global_device_ids = true`, then `channel_id > 0`.
* (C6) `type(results...) = type(operands...)` except:
  * `dim(results..., all_gather_dim) =
    dim(operands..., all_gather_dim) * dim(process_groups, 1)`.

#### Examples

```mlir
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [[1, 2], [3, 4]]
// %operand0@(1, 0): [[5, 6], [7, 8]]
// %operand1@(0, 0): [[11, 12], [13, 14]]
// %operand1@(1, 0): [[15, 16], [17, 18]]
%result:2 = "stablehlo.all_gather"(%operand0, %operand1) {
all_gather_dim = 1 : i64,
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
// channel_id = 0
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
// use_global_device_ids = false
} : (tensor<2x2xi64>, tensor<2x2xi64>) -> (tensor<2x4xi64>, tensor<2x4xi64>)
// %result0@(0, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result0@(1, 0): [[1, 2, 5, 6], [3, 4, 7, 8]]
// %result1@(0, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
// %result1@(1, 0): [[11, 12, 15, 16], [13, 14, 17, 18]]
```
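
As a cross-check of the example above (not part of the spec), each result is simply the corresponding operand concatenated across the two processes along `all_gather_dim = 1`; a small NumPy snippet reproducing the values:

```python
import numpy as np

# Per-process operand values from the example above.
operand0 = {(0, 0): np.array([[1, 2], [3, 4]]), (1, 0): np.array([[5, 6], [7, 8]])}
operand1 = {(0, 0): np.array([[11, 12], [13, 14]]), (1, 0): np.array([[15, 16], [17, 18]])}

# Every process in the group {(0, 0), (1, 0)} receives the same results.
result0 = np.concatenate([operand0[(0, 0)], operand0[(1, 0)]], axis=1)
result1 = np.concatenate([operand1[(0, 0)], operand1[(1, 0)]], axis=1)
# result0 == [[1, 2, 5, 6], [3, 4, 7, 8]]
# result1 == [[11, 12, 15, 16], [13, 14, 17, 18]]
```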

&nbsp;[More Examples](https://github.com/openxla/stablehlo/tree/main/stablehlo/tests/interpret/all_gather.mlir)
@@ -757,8 +761,8 @@ Afterwards, within each `process_group`:
#### Semantics

Within each process group in the StableHLO process grid, applies a reduction
function `computation` to the values of the `operands` tensors from each process
and produces `results` tensors.

The operation splits the StableHLO process grid into `process_groups` which is
defined as follows:
@@ -772,29 +776,29 @@ defined as follows:

Afterwards, within each `process_group`:

* `results...@process[result_index] = exec(schedule)` for some binary tree
`schedule` where:
* `exec(node)` = `computation(exec(node.left), exec(node.right))`.
* `exec(leaf)` = `leaf.value`.
* `schedule` is an implementation-defined binary tree whose in-order
traversal is `to_destination_type(operands...@process_group...[result_index],
type(func_inputs(computation)[0]))`.
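
One legal `schedule` is a plain left-to-right reduction. The sketch below is a non-normative Python illustration that ignores the `to_destination_type` promotion step; the function name, the data layout, and the sample values (taken from the example further down) are assumptions, not spec text:

```python
import numpy as np

def all_reduce_group(operands_per_process, computation):
    """Simulate variadic all_reduce within one process group.

    operands_per_process[p][i] is the i-th operand tensor on process p.
    Every process in the group receives the same reduced results.
    """
    num_operands = len(operands_per_process[0])
    reduced = []
    for i in range(num_operands):
        # In-order traversal of one legal schedule: the i-th operand from each
        # process, combined pairwise from left to right.
        leaves = [proc[i] for proc in operands_per_process]
        acc = leaves[0]
        for leaf in leaves[1:]:
            acc = computation(acc, leaf)
        reduced.append(acc)
    return [reduced for _ in operands_per_process]

proc0 = [np.array([1, 2, 3, 4]), np.array([9, 10, 11, 12])]
proc1 = [np.array([5, 6, 7, 8]), np.array([13, 14, 15, 16])]
results = all_reduce_group([proc0, proc1], computation=np.add)
# results[p][0] == [6, 8, 10, 12] and results[p][1] == [22, 24, 26, 28]
# for every process p in the group.
```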

#### Inputs

| Label | Name | Type | Constraints |
|-------|-------------------------|------------------------------------------------------------------|-------------|
| (I1) | `operands` | variadic number of tensors or per-tensor quantized tensors | (C5), (C6) |
| (I2) | `replica_groups` | variadic number of 1-dimensional tensor constants of type `si64` | (C1-C3) |
| (I3) | `channel_id` | constant of type `si64` | (C4) |
| (I4) | `use_global_device_ids` | constant of type `i1` | (C4) |
| (I5) | `computation` | function | (C5) |

#### Outputs

| Name | Type | Constraints |
|-----------|-------------------------------------------------------------|-------------|
| `results` | variadic number of tensors or per-tensor quantized tensors | (C6-C7) |

#### Constraints

@@ -807,26 +811,30 @@ Afterwards, within each `process_group`:
* (C4) If `use_global_device_ids = true`, then `channel_id > 0`.
* (C5) `computation` has type `(tensor<E>, tensor<E>) -> (tensor<E>)` where
`is_promotable(element_type(operand), E)`.
* (C6) `shape(results...) = shape(operands...)`.
* (C7) `element_type(results...) = E`.

#### Examples

```mlir
// num_replicas: 2
// num_partitions: 1
// %operand0@(0, 0): [1, 2, 3, 4]
// %operand0@(1, 0): [5, 6, 7, 8]
// %operand1@(0, 0): [9, 10, 11, 12]
// %operand1@(1, 0): [13, 14, 15, 16]
%result:2 = "stablehlo.all_reduce"(%operand0, %operand1) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
"stablehlo.return"(%0) : (tensor<i64>) -> ()
}) {
replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>,
channel_handle = #stablehlo.channel_handle<handle = 0, type = 0>
} : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>, tensor<4xi64>)
// %result0@(0, 0): [6, 8, 10, 12]
// %result0@(1, 0): [6, 8, 10, 12]
// %result1@(0, 0): [22, 24, 26, 28]
// %result1@(1, 0): [22, 24, 26, 28]
```

&nbsp;[More Examples](https://github.com/openxla/stablehlo/tree/main/stablehlo/tests/interpret/all_reduce.mlir)
@@ -838,10 +846,9 @@ Afterwards, within each `process_group`:
![all_to_all](images/spec/all_to_all.svg)

Within each process group in the StableHLO process grid, splits the values of
the `operands` tensors along `split_dimension` into parts, scatters the split
parts between the processes, concatenates the scattered parts along
`concat_dimension` and produces `results` tensors.

The operation splits the StableHLO process grid into `process_groups` which is
defined as follows:

@@ -850,48 +857,48 @@ defined as follows:

Afterwards, within each `process_group`:

* `split_parts...@sender = split(operands...@sender, split_count, split_dimension)`
  for all `sender` in `process_group`.
* `scattered_parts...@receiver = [split_parts...@sender[receiver_index] for
  sender in process_group]` where
  `receiver_index = process_group.index(receiver)`.
* `results...@process = concatenate(scattered_parts...@process, concat_dimension)`.
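
Taken together, the three steps are a split/scatter/concatenate within each group. A minimal, non-normative Python/NumPy sketch (the function name, the data layout, and the sample values are illustrative assumptions, not spec text):

```python
import numpy as np

def all_to_all_group(operands_per_process, split_dimension, concat_dimension,
                     split_count):
    """Simulate variadic all_to_all within one process group.

    operands_per_process[p][i] is the i-th operand tensor on process p;
    split_count equals the number of processes in the group.
    """
    num_procs = len(operands_per_process)
    num_operands = len(operands_per_process[0])
    # split_parts[sender][i][k] is the k-th split part of operand i on sender.
    split_parts = [
        [np.split(proc[i], split_count, axis=split_dimension)
         for i in range(num_operands)]
        for proc in operands_per_process
    ]
    results = []
    for receiver in range(num_procs):
        # Each receiver collects its own slot from every sender and
        # concatenates the scattered parts along concat_dimension.
        results.append([
            np.concatenate(
                [split_parts[sender][i][receiver] for sender in range(num_procs)],
                axis=concat_dimension)
            for i in range(num_operands)
        ])
    return results

# Two processes, one operand each; split along dim 1, concatenate along dim 0.
p0 = [np.array([[0, 1], [2, 3]])]
p1 = [np.array([[4, 5], [6, 7]])]
out = all_to_all_group([p0, p1], split_dimension=1, concat_dimension=0, split_count=2)
# out[0][0] == [[0], [2], [4], [6]]  (column 0 of both senders, stacked)
# out[1][0] == [[1], [3], [5], [7]]  (column 1 of both senders, stacked)
```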

#### Inputs

| Label | Name | Type | Constraints |
|-------|--------------------|--------------------------------------------------------------|------------------------|
| (I1) | `operands` | variadic number of tensors or per-tensor quantized tensors | (C1-C3), (C9) |
| (I2) | `split_dimension` | constant of type `si64` | (C1), (C2), (C9) |
| (I3) | `concat_dimension` | constant of type `si64` | (C3), (C9) |
| (I4) | `split_count` | constant of type `si64` | (C2), (C4), (C8), (C9) |
| (I5) | `replica_groups` | 2-dimensional tensor constant of type `si64` | (C5-C8) |
| (I6) | `channel_id` | constant of type `si64` | |

#### Outputs

| Name | Type | Constraints |
|-----------|-------------------------------------------------------------|-------------|
| `results` | variadic number of tensors or per-tensor quantized tensors | (C9) |

#### Constraints

* (C1) `0 <= split_dimension < rank(operands...)`.
* (C2) `dim(operands..., split_dimension) % split_count = 0`.
* (C3) `0 <= concat_dimension < rank(operands...)`.
* (C4) `0 < split_count`.
* (C5) `is_unique(replica_groups)`.
* (C6) `size(replica_groups)` is defined as:
* `num_replicas` if `cross_replica` is used.
* `num_partitions` if `cross_partition` is used.
* (C7) `0 <= replica_groups < size(replica_groups)`.
* (C8) `dim(replica_groups, 1) = split_count`.
* (C9) `type(results...) = type(operands...)` except, if `split_dimension !=
  concat_dimension`:
  * `dim(results..., split_dimension) =
    dim(operands..., split_dimension) / split_count`.
  * `dim(results..., concat_dimension) =
    dim(operands..., concat_dimension) * split_count`.
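
A small, non-normative helper illustrating the shape part of (C9); the function name and example shapes are assumptions, not spec text:

```python
def all_to_all_result_shape(operand_shape, split_dimension, concat_dimension,
                            split_count):
    """Per (C9): shapes only change when split_dimension != concat_dimension."""
    shape = list(operand_shape)
    if split_dimension != concat_dimension:
        shape[split_dimension] //= split_count
        shape[concat_dimension] *= split_count
    return tuple(shape)

# A 2x4 operand split along dim 1 into 2 parts and concatenated along dim 0
# becomes 4x2; with split_dimension == concat_dimension the shape is unchanged.
assert all_to_all_result_shape((2, 4), 1, 0, 2) == (4, 2)
assert all_to_all_result_shape((2, 4), 1, 1, 2) == (2, 4)
```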

#### Examples

34 changes: 34 additions & 0 deletions rfcs/20240312-standardize-collective-ops.md
@@ -0,0 +1,34 @@
# [RFC] Standardize collective ops to support variadic operand/result

Status: Review<br/>
Initial version: 03/12/2024<br/>
Last updated: 03/15/2024<br/>
Discussion thread: [GitHub](https://github.com/openxla/stablehlo/pull/2099)

## Motivation

Several features have been added to MHLO in the past year that frameworks want
to leverage, and members of the community have requested them as well. One such
feature is support for variadic operands/results in the collective
(`AllGatherOp`, `AllReduceOp`, `AllToAllOp`) ops.

We propose adding this feature to the StableHLO spec so it can be used by the community.
StableHLO collective op support is currently limited to **single-operand** and **single-result**.
[MHLO collective ops](https://github.com/tensorflow/mlir-hlo/blob/master/mhlo/IR/hlo_ops.td)
support
**multi-operand** and **multi-result**, which is in sync with the multi-operand and
multi-result XLA semantics
([`all_reduce`](https://openxla.org/xla/operation_semantics#allreduce),
[`all_gather`](https://openxla.org/xla/operation_semantics#allgather) and
[`all_to_all`](https://openxla.org/xla/operation_semantics#alltoall)) and
horizontal scaling. `all_reduce`
support is requested
in [#1370](https://github.com/openxla/stablehlo/issues/1370) and is relied on by
PyTorch/XLA today via XlaBuilder ([ref](https://github.com/pytorch/xla/blob/1bbe333ad137ace6b8134db640c0b24c8c428db6/torch_xla/csrc/cross_replica_reduces.cpp#L156)).
`all_to_all` support is requested in
[#574](https://github.com/openxla/stablehlo/issues/574) and identified as a feature
gap.

## Proposed Specification Changes

Please refer to the spec.md changes in this PR to view the diff against the original spec.
