ONNX-Native Tensor Parallel: Using Distributed MatMul as Example (#17695)

This PR introduces
- A new data structure to represent kernel-level (aka node-level or
op-level) tensor sharding information. I consider it the foundation
of ONNX distributed inference.
- Building blocks for implementing distributed kernels, especially
stateless implementations of communication ops.
- Implementation of DistributedMatMul and its tests.

Code structure:
- sharding.h/.cc: Functions to shard and reshard tensors (calling into
NCCL).
- sharding_spec.h/.cc: Representation of how a tensor is sharded.
- distributed_matmul.h/.cc: Implementation of tensor-parallel MatMul.
Inputs and outputs are sharded across devices.
- onnxruntime_test_distributed.py: Distributed operator tests.

Example of specifying sharding information:
```python
import onnxscript
from onnxscript import FLOAT

# Handle for the com.microsoft (contrib ops) domain. The exact construction is
# an assumption here; the PR's test builds an equivalent opset handle.
MICROSOFT_OPSET = onnxscript.values.Opset(domain="com.microsoft", version=1)

@onnxscript.script()
def matmul_rs_sr_rr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT:
    # Run MatMul by sharding x along its column axis and w along its row axis
    # on 2 GPUs.
    return MICROSOFT_OPSET.DistributedMatMul(
        tensor_x,
        tensor_w,
        device_mesh_shape=[2],
        device_mesh_elements=[0, 1],
        input_shard_specs=["RS[0]", "S[0]R"],
        output_shard_specs=["RR"],
    )

onnx_model = matmul_rs_sr_rr.to_model_proto(
    input_types=[FLOAT[2, "s"], FLOAT["s", 2]],
    output_types=[FLOAT[2, 2]],
)
```

In this example, the device mesh can be visualized as a 1-D tensor, `[0,
1]`. The 2nd axis of `tensor_x` is sharded across `[0, 1]` (i.e., along the
0th axis of the device mesh). Similarly, the 1st axis of `tensor_w` is
sharded across `[0, 1]`.
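
To make the semantics concrete, here is a minimal numpy sketch (illustration only, not code from this PR) of what the `"RS[0]"`/`"S[0]R"` sharding does numerically: each device holds one column block of `tensor_x` and the matching row block of `tensor_w`, the local MatMuls produce partial products, and summing them (which is what an AllReduce does across devices) recovers the replicated `"RR"` output.

```python
import numpy as np

s = 4  # stand-in for the symbolic dimension "s" above
x = np.arange(2 * s, dtype=np.float32).reshape(2, s)  # full tensor_x, spec "RS[0]"
w = np.arange(s * 2, dtype=np.float32).reshape(s, 2)  # full tensor_w, spec "S[0]R"

x_shards = np.split(x, 2, axis=1)  # column blocks, one per device in mesh [0, 1]
w_shards = np.split(w, 2, axis=0)  # row blocks, one per device in mesh [0, 1]

# Each device multiplies its local shards; the sum of the partial products
# (computed by an AllReduce in the real kernel) equals the unsharded MatMul.
partials = [x_shards[rank] @ w_shards[rank] for rank in range(2)]
assert np.allclose(sum(partials), x @ w)
```

In the PR's test, each rank feeds only its local shards to the model; the DistributedMatMul kernel performs the cross-device communication internally.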

C++ classes to represent tensor sharding (copied from sharding_spec.h):
```cpp
class DeviceMesh {
 public:
  // [Device Mesh and Tensor Sharding for Tensor Parallel]
  // Device mesh is a tensor of device indices.
  // A tensor can then be partitioned along specific mesh axes.
  //
  // Assume we have 4 GPUs indexed by 0, 1, 2, and 3.
  // Let's consider some examples.
  //  1. 1D device mesh [0, 1, 2, 3]. In this case,
  //     device_mesh_shape is [4] and device_mesh_elements
  //     is [0, 1, 2, 3].
  //     If we want to shard a 2-D tensor along its axis 1, the
  //     corresponding sharding spec is a string "RS[0]".
  //  2. 2D device mesh [[0, 1], [2, 3]]. In this case,
  //     device_mesh_shape is [2, 2] and device_mesh_elements
  //     is [0, 1, 2, 3].
  //     If we want to shard a 2-D tensor's
  //     rows along mesh axis 1 and
  //     columns along mesh axis 0, the
  //     corresponding sharding spec is a string "S[1]S[0]".
  //     If that 2-D tensor's value is np.array([[5, 6], [7, 8]]),
  //     GPU 0/1/2/3 owns 5/7/6/8. Below is a visualization of the sharding
  //     process.
  //     - Start with a 2-D device mesh [[0, 1], [2, 3]] and
  //       a 2-D tensor [[5, 6], [7, 8]]
  //       - GPU: [[0, 1], [2, 3]], Tensor: [[5, 6], [7, 8]]
  //     - Split GPU mesh along axis 1 and tensor along
  //       axis 0 for "S[1]" in "S[1]S[0]"
  //       - GPU: [[0], [2]], Tensor: [[5, 6]]
  //         GPU: [[1], [3]], Tensor: [[7, 8]]
  //     - Split GPU mesh along axis 0 and tensor along
  //       axis 1 for "S[0]" in "S[1]S[0]"
  //       - GPU: [[0]], Tensor: [[5]]
  //       - GPU: [[2]], Tensor: [[6]]
  //       - GPU: [[1]], Tensor: [[7]]
  //       - GPU: [[3]], Tensor: [[8]]

  // Actual shape of device mesh represented by `device_mesh_elements`.
  std::vector<int64_t> device_mesh_shape;

  // Flattened device mesh.
  std::vector<int64_t> device_mesh_elements;
};

class AxisPartitionSpec {
  // [Device Mesh and Tensor Sharding for Tensor Parallel]
  // This class is the in-memory representation of
  //  1. whether a tensor is sharded or not (aka replicated), and
  //  2. which tensor axis is sharded by which device mesh axis.
  // Let's consider sharding 2-D tensor along column axis on
  // device mesh [0, 1] as an example.
  // The required sharding spec RS[0] can be represented by
  // - AxisPartitionSpec(Condition::Replica, -1)
  // - AxisPartitionSpec(Condition::Shard, 0)
 public:
  // Status of a tensor axis.
  // A tensor axis can be either sharded or replicated
  // along a device mesh axis.
  enum class Condition { Replica,
                         Shard };

  // This field tells if a tensor axis is sharded or not.
  Condition cond;

  // If a tensor axis is sharded, this field tells which device
  // mesh axis to distribute the shards along.
  // If a tensor axis is not sharded, this field is ignored.
  int device_mesh_axis;

  // A helper to construct a replica spec for a tensor axis.
  static AxisPartitionSpec CreateReplica() {
    return AxisPartitionSpec(Condition::Replica, -1);
  }

  // A helper to construct a sharding spec for a tensor axis.
  // This tensor axis is sharded along `device_mesh_axis` in device mesh.
  static AxisPartitionSpec CreateShard(int device_mesh_axis) {
    return AxisPartitionSpec(Condition::Shard, device_mesh_axis);
  }
};

class TensorPartitionSpec {
  // [Device Mesh and Tensor Sharding for Tensor Parallel]
  // TensorPartitionSpec holds a collection of AxisPartitionSpec and an
  // associated DeviceMesh. It is responsible for determining how a tensor
  // should be partitioned across a device mesh.
  //
  // Example 1: RS[0]
  // In this scenario, `axis_specs` would contain two `AxisPartitionSpec` objects.
  // - The first object is a Replica, denoting that the first axis of the tensor is
  //   not sharded but is instead replicated.
  // - The second object is a Shard along the 0-th axis of the device mesh. It denotes
  //   that the second axis of the tensor is sharded along the first axis of the
  //   device mesh.
  //
  // Example 2: S[0]RR
  // In this scenario, `axis_specs` would contain three `AxisPartitionSpec` objects.
  // - The first object is a Shard along the 0-th axis of the device mesh, indicating
  //   that the first axis of the tensor is sharded along the first axis of the
  //   device mesh.
  // - The second and third objects are Replicas, indicating that the second and third
  //   axes of the tensor are not sharded but are instead replicated.
 public:
  // axis_specs[i]: AxisPartitionSpec for tensor axis i. For a 2-D tensor,
  //                axis_specs[0] is for row axis and axis_specs[1] is for
  //                column axis. axis_specs[i].device_mesh_axis = j means that
  //                tensor axis i is sharded along device mesh axis j.
  std::vector<AxisPartitionSpec> axis_specs;

  // device_mesh: DeviceMesh for sharding the associated tensor.
  // Read [Device Mesh and Tensor Sharding for Tensor Parallel] in DeviceMesh's comment.
  DeviceMesh device_mesh;
};
```
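
For intuition about the spec-string format, the following Python sketch (illustration only; the PR's C++ implementation goes through the `CreateTensorPartitionSpec` helper used by the DistributedMatMul kernel below) decomposes a string such as `"RS[0]"` or `"S[0]RR"` into the per-axis (condition, device-mesh-axis) pairs that `AxisPartitionSpec` stores:

```python
import re

def parse_shard_spec(spec: str) -> list[tuple[str, int]]:
    """Split a shard-spec string into one entry per tensor axis.

    "R"    -> the tensor axis is replicated (device mesh axis unused, -1).
    "S[j]" -> the tensor axis is sharded along device-mesh axis j.
    """
    axis_specs = []
    for token in re.findall(r"R|S\[\d+\]", spec):
        if token == "R":
            axis_specs.append(("Replica", -1))
        else:
            axis_specs.append(("Shard", int(token[2:-1])))  # "S[j]" -> j
    return axis_specs

assert parse_shard_spec("RS[0]") == [("Replica", -1), ("Shard", 0)]  # Example 1 above
assert parse_shard_spec("S[0]RR") == [("Shard", 0), ("Replica", -1), ("Replica", -1)]  # Example 2 above
```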
wschin authored Oct 5, 2023
1 parent 742069a commit faef9c3
Showing 23 changed files with 1,969 additions and 11 deletions.
4 changes: 3 additions & 1 deletion cmake/onnxruntime_providers.cmake
@@ -404,6 +404,9 @@ if (onnxruntime_USE_CUDA)
if (NOT onnxruntime_USE_NCCL)
list(REMOVE_ITEM onnxruntime_cuda_contrib_ops_cc_srcs
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/nccl_kernels.cc"
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/sharding_spec.cc"
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/sharding.cc"
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_matmul.cc"
)
endif()
# add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio
@@ -452,7 +455,6 @@ if (onnxruntime_USE_CUDA)
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/nccl_kernels.cc"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/megatron.cc"
)

list(REMOVE_ITEM onnxruntime_providers_cuda_src ${onnxruntime_cuda_nccl_op_srcs})
endif()
endif()
5 changes: 5 additions & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -109,7 +109,12 @@ if (NOT onnxruntime_ENABLE_ATEN)
list(APPEND contrib_ops_excluded_files "aten_ops/aten_op.cc")
endif()
if (NOT onnxruntime_USE_NCCL)
# Those are string patterns to exclude. Do NOT use stars such as
# collective/*.cc or *.h.
list(APPEND contrib_ops_excluded_files "collective/nccl_kernels.cc")
list(APPEND contrib_ops_excluded_files "collective/sharding.cc")
list(APPEND contrib_ops_excluded_files "collective/sharding_spec.cc")
list(APPEND contrib_ops_excluded_files "collective/distributed_matmul.cc")
endif()

set(provider_excluded_files
306 changes: 306 additions & 0 deletions onnxruntime/contrib_ops/cuda/collective/distributed_matmul.cc
@@ -0,0 +1,306 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// Distributed computation.
#include "sharding.h"
#include "distributed_matmul.h"
#include "nccl_kernels.h"
#include "mpi_include.h"

// ORT system.
#include "core/providers/cpu/tensor/slice.h"
#include "core/providers/cuda/tensor/slice.h"
#include "core/providers/cuda/math/matmul.h"
#include "core/providers/cuda/tensor/transpose.h"
#include "core/providers/cuda/cuda_check_memory.h"

// std C++.
#include <iostream>

namespace onnxruntime {
namespace contrib {
namespace cuda {

#if defined(ORT_USE_NCCL)

static TensorShape InferMatmulOutputShape(
const TensorShape& shape_A,
const TensorShape& shape_B) {
// left_shape: [M, K]
// right_shape: [K, N]
// output_shape: [M, N]
ORT_ENFORCE(
shape_A.NumDimensions() >= 2 && shape_B.NumDimensions() >= 2,
"1-D tensor is not supported by this MatMul.");
ORT_ENFORCE(
shape_A.NumDimensions() == shape_B.NumDimensions(),
"A and B must have the same rank after shape broadcasting.");
size_t rank = shape_A.NumDimensions();
std::vector<int64_t> shape_Y(rank, 0);
for (size_t i = 0; i < rank; ++i) {
const int64_t dim_A = shape_A[i];
const int64_t dim_B = shape_B[i];
if (i == rank - 1) {
shape_Y[i] = dim_B;
} else if (i == rank - 2) {
shape_Y[i] = dim_A;
} else if (dim_A == 1 && dim_B >= 1) {
// dim_A is 1.
// dim_B can be either 1 or another positive integer
// due to shape broadcasting.
shape_Y[i] = dim_B;
} else if (dim_B == 1 && dim_A >= 1) {
// dim_B is 1.
// dim_A can be either 1 or another positive integer
// due to shape broadcasting.
shape_Y[i] = dim_A;
} else {
ORT_ENFORCE(dim_A == dim_B, "Broadcasting can only happen when one of dim_A and dim_B is 1.");
shape_Y[i] = dim_A;
}
}
return TensorShape(shape_Y);
};

template <typename T>
DistributedMatMul<T>::DistributedMatMul(const OpKernelInfo& info) : NcclKernel(info) {
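// Read the device-mesh and shard-spec attributes, then parse each spec string
// (e.g., "RS[0]") into an in-memory TensorPartitionSpec used by ComputeInternal.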
std::vector<int64_t> device_mesh_elements = info.GetAttrsOrDefault<int64_t>("device_mesh_elements");
std::vector<int64_t> device_mesh_shape = info.GetAttrsOrDefault<int64_t>("device_mesh_shape");
std::vector<std::string> input_shard_specs = info.GetAttrsOrDefault<std::string>("input_shard_specs");
std::vector<std::string> output_shard_specs = info.GetAttrsOrDefault<std::string>("output_shard_specs");

for (size_t i = 0; i < input_shard_specs.size(); ++i) {
auto spec = CreateTensorPartitionSpec(input_shard_specs[i], device_mesh_shape, device_mesh_elements);
input_shard_specs_.push_back(spec);
}
for (size_t i = 0; i < output_shard_specs.size(); ++i) {
auto spec = CreateTensorPartitionSpec(output_shard_specs[i], device_mesh_shape, device_mesh_elements);
output_shard_specs_.push_back(spec);
}
}

template <typename T>
Status DistributedMatMul<T>::ComputeInternal(OpKernelContext* context) const {
const auto tensor_shard_A = context->Input<Tensor>(0);
const auto tensor_shard_B = context->Input<Tensor>(1);
const auto& tensor_shard_shape_A = tensor_shard_A->Shape();
const auto& tensor_shard_shape_B = tensor_shard_B->Shape();

auto rank_A = tensor_shard_shape_A.NumDimensions();
auto rank_B = tensor_shard_shape_B.NumDimensions();
// TODO(wechi): Fix MatMul(1-D, *) and MatMul(*, 1-D) cases.
ORT_ENFORCE(rank_A >= 2 && rank_B >= 2, "Broadcast rule for 1-D tensor is different than other cases.");

const TensorPartitionSpec& spec_A = input_shard_specs_[0];
const TensorPartitionSpec& spec_B = input_shard_specs_[1];
const TensorPartitionSpec& spec_Y = output_shard_specs_[0];

const auto tensor_shape_A = ComputeOriginShape(tensor_shard_shape_A, spec_A);
const auto tensor_shape_B = ComputeOriginShape(tensor_shard_shape_B, spec_B);

TensorShape normalized_shape_A;
TensorShape normalized_shape_B;
std::tie(normalized_shape_A, normalized_shape_B) = NormalizeShapes(tensor_shape_A, tensor_shape_B);

TensorPartitionSpec normalized_spec_A;
TensorPartitionSpec normalized_spec_B;
std::tie(normalized_spec_A, normalized_spec_B) = NormalizeTensorPartitionSpecs(spec_A, spec_B);

const auto tensor_shape_Y = InferMatmulOutputShape(normalized_shape_A, normalized_shape_B);
const auto tensor_shard_shape_Y = ComputeShardShape(tensor_shape_Y, spec_Y);

// Case 1: A is not sharded, B is sharded.
// 1. shard on -1: MatMul(RR, RS) -> RS
// 2. shard on -2: MatMul(RR, SR) -> MatMul(RS, SR) + AllReduce -> RR
// 3. shard on other axis
if (normalized_spec_A.HasNoShard() && normalized_spec_B.HasShard()) {
if (normalized_spec_B.OnlyShardAxis(-1)) {
// Case 1-1
// MatMul(RR, RS) -> RS
ORT_ENFORCE(spec_Y.OnlyShardAxis(-1), "Not supported yet.");
auto tensor_shard_Y = context->Output(0, tensor_shard_shape_Y);
ORT_ENFORCE(onnxruntime::cuda::FuncMatMul<T>(
this, context, tensor_shard_A, tensor_shard_B, 1.0, false, false, false, false, tensor_shard_Y) == Status::OK());
} else if (normalized_spec_B.OnlyShardAxis(-2)) {
// Case 1-2
// MatMul(RR, SR) -> MatMul(RS, SR) + AllReduce -> RR
auto tmp_spec_A = CreateTensorShardSpec(spec_A.device_mesh, 0, -1, rank_A);
auto tmp_tensor_shard_A = ReshardTensor(this, context, spec_A, tmp_spec_A, nccl_->Rank(), tensor_shard_A);

auto tensor_shard_Y = context->Output(0, tensor_shard_shape_Y);
ORT_ENFORCE(onnxruntime::cuda::FuncMatMul<T>(
this, context, tmp_tensor_shard_A.get(), tensor_shard_B, 1.0, false, false, false, false, tensor_shard_Y) == Status::OK());
ORT_ENFORCE(FuncAllReduce(
nccl_->Comm(), Stream(context), tensor_shard_Y, tensor_shard_Y) == Status::OK());
} else {
// Case 1-3
ORT_THROW("Not supported yet.");
}
}

// Case 2: A is sharded, B is not sharded.
// 1. shard on -1: MatMul(RS, RR) -> MatMul(RS, SR) + AllReduce -> RR
// 2. shard on -2: MatMul(SR, RR) -> SR
// 3. shard on other axis: MatMul(SRR, RRR) -> MatMul(SRR, SRR) -> SRR
if (spec_A.HasShard() && spec_B.HasNoShard()) {
if (spec_A.OnlyShardAxis(-1) && spec_Y.HasNoShard()) {
// Case 2-1
// Y is not really sharded in this case.
auto tensor_shard_Y = context->Output(0, tensor_shard_shape_Y);

// TODO: Support cases with multi-dimension device mesh.
TensorPartitionSpec new_spec_B = CreateTensorShardSpec(spec_B.device_mesh, 0, -2, rank_B);
auto tensor_reshard_B = ShardTensor(this, context, new_spec_B, nccl_->Rank(), tensor_shard_B);

ORT_ENFORCE(onnxruntime::cuda::FuncMatMul<T>(
this, context, tensor_shard_A, tensor_reshard_B.get(), 1.0, false, false, false, false, tensor_shard_Y) == Status::OK());

ORT_ENFORCE(FuncAllReduce(
nccl_->Comm(), Stream(context), tensor_shard_Y, tensor_shard_Y) == Status::OK());
return Status::OK();
} else if (spec_A.OnlyShardAxis(-2) && spec_Y.OnlyShardAxis(-2)) {
// Case 2-2
auto tensor_shard_Y = context->Output(0, tensor_shard_shape_Y);
ORT_ENFORCE(onnxruntime::cuda::FuncMatMul<T>(
this, context, tensor_shard_A, tensor_shard_B, 1.0, false, false, false, false, tensor_shard_Y) == Status::OK());
return Status::OK();
} else if (spec_A.GetPartitionAxis() < gsl::narrow<int64_t>(tensor_shape_A.NumDimensions()) - 2 && normalized_spec_A.GetPartitionAxis() == spec_Y.GetPartitionAxis()) {
// Case 2-3
if (normalized_shape_B[normalized_spec_A.GetPartitionAxis()] == 1) {
// Case 2-3-1.
// B is broadcast along the sharding axis of A.
// E.g., MatMul(A(SRR), B(RR)) where normalized_shape_A = [2, 3, 4] and normalized_shape_B = [1, 4, 3].
// No resharding is required.
auto tensor_shard_Y = context->Output(0, tensor_shard_shape_Y);
ORT_ENFORCE(onnxruntime::cuda::FuncMatMul<T>(
this, context, tensor_shard_A, tensor_shard_B, 1.0, false, false, false, false, tensor_shard_Y) == Status::OK());
return Status::OK();
} else {
// Case 2-3-2.
// No broadcasting.
// Allocate the output shard, whose non-matrix axis is sharded like A's.
// MatMul(SRR, RRR) -> MatMul(SRR, SRR) -> SRR
auto tensor_shard_Y = context->Output(0, tensor_shard_shape_Y);
TensorPartitionSpec new_spec_B = CreateTensorShardSpec(
spec_B.device_mesh,
0,
spec_A.GetNegativePartitionAxis(),
rank_B);
auto tensor_reshard_B = ShardTensor(this, context, new_spec_B, nccl_->Rank(), tensor_shard_B);
ORT_ENFORCE(onnxruntime::cuda::FuncMatMul<T>(
this, context, tensor_shard_A, tensor_reshard_B.get(), 1.0, false, false, false, false, tensor_shard_Y) == Status::OK());
return Status::OK();
}
} else {
ORT_THROW("Not supported yet.");
}
}

// Case 3: A is sharded, B is sharded.
// 1. shard on (-1, -1): MatMul(RS, RS) -> MatMul(RS, SR) + AllReduce -> RR
// -> MatMul(RR, RS) -> RS
// 2. shard on (-1, -2): MatMul(RS, SR) -> MatMul(RS, SR) + AllReduce -> RR
// 3. shard on (-2, -1): MatMul(SR, RS) -> MatMul(RS, SR) + AllReduce -> RR
// 4. shard on (-2, -2): MatMul(SR, SR) -> MatMul(RS, SR) + AllReduce -> RR
// 5. shard on other axes
if (spec_A.HasShard() && spec_B.HasShard()) {
if (spec_A.OnlyShardAxis(-1) && spec_B.OnlyShardAxis(-1)) {
// Case 3-1
if (spec_Y.HasNoShard()) {
// Case 3-1-1
auto tmp_spec_B = CreateTensorShardSpec(spec_B.device_mesh, 0, -2, rank_B);
auto tmp_tensor_shard_B = ReshardTensor(this, context, spec_B, tmp_spec_B, nccl_->Rank(), tensor_shard_B);
auto tensor_shard_Y = context->Output(0, tensor_shard_shape_Y);
ORT_ENFORCE(onnxruntime::cuda::FuncMatMul<T>(
this, context, tensor_shard_A, tmp_tensor_shard_B.get(), 1.0, false, false, false, false, tensor_shard_Y) == Status::OK());
ORT_ENFORCE(FuncAllReduce(
nccl_->Comm(), Stream(context), tensor_shard_Y, tensor_shard_Y) == Status::OK());
} else if (spec_Y.OnlyShardAxis(-1)) {
// Case 3-1-2
auto tmp_spec_A = TensorPartitionSpec::CreateAllReplica(spec_A);
auto tmp_tensor_shard_A = ReshardTensor(this, context, spec_A, tmp_spec_A, nccl_->Rank(), tensor_shard_A);
auto tensor_shard_Y = context->Output(0, tensor_shard_shape_Y);
ORT_ENFORCE(onnxruntime::cuda::FuncMatMul<T>(
this, context, tmp_tensor_shard_A.get(), tensor_shard_B, 1.0, false, false, false, false, tensor_shard_Y) == Status::OK());
} else {
ORT_THROW("Not supported yet.");
}
} else if (spec_A.OnlyShardAxis(-1) && spec_B.OnlyShardAxis(-2) && spec_Y.HasNoShard()) {
// Case 3-2
auto tensor_shard_Y = context->Output(0, tensor_shard_shape_Y);

auto status = onnxruntime::cuda::FuncMatMul<T>(
this, context, tensor_shard_A, tensor_shard_B, 1.0, false, false, false, false, tensor_shard_Y);
ORT_ENFORCE(status == Status::OK(), status.ErrorMessage());

status = FuncAllReduce(
nccl_->Comm(), Stream(context), tensor_shard_Y, tensor_shard_Y);
ORT_ENFORCE(status == Status::OK(), status.ErrorMessage());
} else if (spec_A.OnlyShardAxis(-2) && spec_B.OnlyShardAxis(-1)) {
// Case 3-3:
// MatMul(SR, RS) -> MatMul(RS, SR) + AllReduce -> RR
ORT_ENFORCE(spec_Y.HasNoShard(), "Not supported yet.");

// A[RS]
auto tmp_spec_A = CreateTensorShardSpec(spec_A.device_mesh, 0, -1, rank_A);
auto tmp_tensor_shard_A = ReshardTensor(this, context, spec_A, tmp_spec_A, nccl_->Rank(), tensor_shard_A);

// B[SR]
auto tmp_spec_B = CreateTensorShardSpec(spec_B.device_mesh, 0, -2, rank_B);
auto tmp_tensor_shard_B = ReshardTensor(this, context, spec_B, tmp_spec_B, nccl_->Rank(), tensor_shard_B);

// Allocate Y[RR]
auto tensor_shard_Y = context->Output(0, tensor_shard_shape_Y);

// Run local MatMul(A[RS], B[SR])
ORT_ENFORCE(onnxruntime::cuda::FuncMatMul<T>(
this, context, tmp_tensor_shard_A.get(), tmp_tensor_shard_B.get(), 1.0, false, false, false, false, tensor_shard_Y) == Status::OK());
ORT_ENFORCE(FuncAllReduce(
nccl_->Comm(), Stream(context), tensor_shard_Y, tensor_shard_Y) == Status::OK());
} else if (spec_A.OnlyShardAxis(-2) && spec_B.OnlyShardAxis(-2)) {
// Case 3-4
// MatMul(SR, SR) -> MatMul(RS, SR) + AllReduce -> RR
ORT_ENFORCE(spec_Y.HasNoShard(), "Not supported yet.");
auto tmp_spec_A = CreateTensorShardSpec(spec_A.device_mesh, 0, -1, rank_A);
auto tmp_tensor_shard_A = ReshardTensor(this, context, spec_A, tmp_spec_A, nccl_->Rank(), tensor_shard_A);
auto tensor_shard_Y = context->Output(0, tensor_shard_shape_Y);
ORT_ENFORCE(onnxruntime::cuda::FuncMatMul<T>(
this, context, tmp_tensor_shard_A.get(), tensor_shard_B, 1.0, false, false, false, false, tensor_shard_Y) == Status::OK());
} else {
// Case 3-5
ORT_THROW("Not supported yet.");
}
}

// Case 4: A is not sharded, B is not sharded.
// - Easy!
return Status::OK();
}

ONNX_OPERATOR_TYPED_KERNEL_EX(
DistributedMatMul,
kMSDomain,
1,
float,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.AllocateInputsContiguously()
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
DistributedMatMul<float>);

ONNX_OPERATOR_TYPED_KERNEL_EX(
DistributedMatMul,
kMSDomain,
1,
MLFloat16,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.AllocateInputsContiguously()
.TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
DistributedMatMul<MLFloat16>);

#endif

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
38 changes: 38 additions & 0 deletions onnxruntime/contrib_ops/cuda/collective/distributed_matmul.h
@@ -0,0 +1,38 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "sharding_spec.h"
#include "core/providers/cuda/cuda_kernel.h"

#include <algorithm>
#include <tuple>
#include <optional>
#include <string>
#include <nccl.h>
#include <sstream>

#pragma once

namespace onnxruntime {
namespace contrib {
namespace cuda {

#if defined(ORT_USE_NCCL)

template <typename T>
class DistributedMatMul final : public NcclKernel {
public:
explicit DistributedMatMul(const OpKernelInfo& info);

Status ComputeInternal(OpKernelContext* context) const override;

private:
std::vector<TensorPartitionSpec> input_shard_specs_;
std::vector<TensorPartitionSpec> output_shard_specs_;
};

#endif

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime