From faef9c32fa7740935ee6cde5b00adab59d641854 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Thu, 5 Oct 2023 14:22:25 -0700 Subject: [PATCH] ONNX-Native Tensor Parallel: Using Distributed MatMul as Example (#17695) This PR introduces - New data structure to represent kernel-level (aka node-level or op-level) tensor sharding informaiton. I consider it as the fundamentaion of ONNX distribtued inference. - Building blocks for distribtued kernels implementation especially stateless implementation for communication ops. - Implementation of DistributedMatMul and its tests. Code structure: - sharding.h/.cc: Function to shard and reshard tensors (calling into NCCL). - sharding_spec.h/.cc: Representation of how a tensor is sharded. - distributed_matmul.h/.cc: Implementation of tensor parallel MatMul. Inputs and outputs are sharded across devices. - onnxruntime_test_distributed.py: distributed operator tests. Example of specifying sharding information ```python @onnxscript.script() def matmul_rs_sr_rr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT: # Run MatMul by sharding x along column axis and w along row axis on # 2 GPUs. return MICROSOFT_OPSET.DistributedMatMul( tensor_x, tensor_w, device_mesh_shape=[2], device_mesh_elements=[0, 1], input_shard_specs=["RS[0]", "S[0]R"], output_shard_specs=["RR"], ) onnx_model = matmul_rs_sr_rr.to_model_proto( input_types=[FLOAT[2, "s"], FLOAT["s", 2]], output_types=[FLOAT[2, 2]], ) ``` In this example, the device mesh can be visualized as 1-D tensor, `[0, 1]`. The 2nd axis of `tensor_x` is sharded across `[0, 1]` (i.e., the 0-axis of the device mesh). Similarly, the 1st axis of `tensor_w` is sharded across `[0, 1]` as well. C++ classes to represent tensor sharding (copied from sharding_spec.h): ```cpp class DeviceMesh { public: // [Device Mesh and Tensor Sharding for Tensor Parallel] // Device mesh is a tensor of device indices. // A tensor can then be partitioned along specific mesh axes. // // Assume we have 4 GPUs indexed by 0, 1, 2, and 3. // Let's consider some examples. // 1. 1D device mesh [0, 1, 2, 3]. In this case, // device_mesh_shape is [4] and device_mesh_elements // is [0, 1, 2, 3]. // If we want to shard a 2-D tensor along its axis 1, the // corresponding sharding spec is a string "RS[0]". // 2. 2D device mesh [[0, 1], [2, 3]]. In this case, // device_mesh_shape is [2, 2] and device_mesh_elements // is [0, 1, 2, 3]. // If we want to shard a 2-D tensor's // rows along mesh axis 1 and // columns along mesh axis 0, the // corresponding sharding spec is a string "S[1]S[0]". // If that 2-D tensor's value is np.array([[5, 6], [7, 8]]), // GPU 0/1/2/3 owns 5/7/6/8. Below is a visualization the sharding // proccess. // - Start with a 2-D device mesh [[0, 1], [2, 3]] and // a 2-D tensor [[5, 6], [7, 8]] // - GPU: [[0, 1], [2, 3]], Tensor: [[5, 6], [7, 8]] // - Split GPU mesh along axis 1 and tensor along // axis 0 for "S[1]" in "S[1]S[0]" // - GPU: [[0], [2]], Tensor: [[5, 6]] // GPU: [[1], [3]], Tensor: [[7, 8]] // - Split GPU mesh along axis 0 and tensor along // axis 1 for "S[0]" in "S[1]S[0]" // - GPU: [[0]], Tensor: [[5]] // - GPU: [[2]], Tensor: [[6]] // - GPU: [[1]], Tensor: [[7]] // - GPU: [[3]], Tensor: [[8]] // Actual shape of device mesh represented by `device_mesh_elements`. std::vector device_mesh_shape; // Flattened device mesh. std::vector device_mesh_elements; }; class AxisPartitionSpec { // [Device Mesh and Tensor Sharding for Tensor Parallel] // This class is the in-memory representation of // 1. if a tensor is sharded or not (aka replica), and // 2. which tensor axis is shard by which device mesh axis. // Let's consider sharding 2-D tensor along column axis on // device mesh [0, 1] as an example. // The required sharding spec RS[0] can be represented by // - AxisPartitionSpec(Condition::Replica, -1) // - AxisPartitionSpec(Condition::Shard, 0) public: // Status of a tensor axis. // A tensor axis can be either sharded or replicated // along a device mesh axis. enum class Condition { Replica, Shard }; // This field tells if a tensor axis is sharded or not. Condition cond; // If a tensor axis is sharded, this field tells which device // mesh axis to distribute the shards along. // If a tensor axis is not sharded, this field is ignored. int device_mesh_axis; // A helper to construct a replica spec for a tensor axis. static AxisPartitionSpec CreateReplica() { return AxisPartitionSpec(Condition::Replica, -1); } // A helper to construct a sharding spec for a tensor axis. // This tensor axis is sharded along `device_mesh_axis` in device mesh. static AxisPartitionSpec CreateShard(int device_mesh_axis) { return AxisPartitionSpec(Condition::Shard, device_mesh_axis); } }; class TensorPartitionSpec { // [Device Mesh and Tensor Sharding for Tensor Parallel] // TensorPartitionSpec holds a collection of AxisPartitionSpec and an // associated DeviceMesh. It is responsible for determining how a tensor // should be partitioned across a device mesh. // // Example 1: RS[0] // In this scenario, `axis_specs` would contain two `AxisPartitionSpec` objects. // - The first object is a Replica, denoting that the first axis of the tensor is // not sharded but is instead replicated. // - The second object is a Shard along the 0-th axis of the device mesh. It denotes // that the second axis of the tensor is sharded along the first axis of the // device mesh. // // Example 2: S[0]RR // In this scenario, `axis_specs` would contain three `AxisPartitionSpec` objects. // - The first object is a Shard along the 0-th axis of the device mesh, indicating // that the first axis of the tensor is sharded along the first axis of the // device mesh. // - The second and third objects are Replicas, indicating that the second and third // axes of the tensor are not sharded but are instead replicated. public: // axis_specs[i]: AxisPartitionSpec for tensor axis i. For a 2-D tensor, // axis_specs[0] is for row axis and axis_specs[1] is for // column axis. axis_specs[i].device_mesh_axis = j means that // tensor axis i is sharded along device mesh axis j. std::vector axis_specs; // device_mesh: DeviceMesh for sharding the associated tensor. // Read [Device Mesh and Tensor Sharding for Tensor Parallel] in DeviceMesh's comment. DeviceMesh device_mesh; }; ``` --- cmake/onnxruntime_providers.cmake | 4 +- cmake/onnxruntime_rocm_hipify.cmake | 5 + .../cuda/collective/distributed_matmul.cc | 306 ++++++++++++++ .../cuda/collective/distributed_matmul.h | 38 ++ .../cuda/collective/nccl_kernels.cc | 113 ++++++ .../cuda/collective/nccl_kernels.h | 30 ++ .../contrib_ops/cuda/collective/sharding.cc | 219 ++++++++++ .../contrib_ops/cuda/collective/sharding.h | 64 +++ .../cuda/collective/sharding_spec.cc | 193 +++++++++ .../cuda/collective/sharding_spec.h | 374 ++++++++++++++++++ .../contrib_ops/cuda/cuda_contrib_kernels.cc | 6 + .../core/graph/contrib_ops/collective_defs.cc | 26 ++ .../core/providers/cpu/cpu_provider_shared.cc | 6 + .../core/providers/cpu/cpu_provider_shared.h | 4 + .../core/providers/cpu/tensor/slice.cc | 20 +- onnxruntime/core/providers/cpu/tensor/slice.h | 4 + .../core/providers/cuda/math/matmul.cc | 155 ++++++++ onnxruntime/core/providers/cuda/math/matmul.h | 18 + .../core/providers/cuda/tensor/slice.cc | 54 +++ .../core/providers/cuda/tensor/slice.h | 15 + .../provider_bridge_provider.cc | 7 + .../python/onnxruntime_test_distributed.py | 317 +++++++++++++++ ...ortmodule-distributed-test-ci-pipeline.yml | 2 +- 23 files changed, 1969 insertions(+), 11 deletions(-) create mode 100644 onnxruntime/contrib_ops/cuda/collective/distributed_matmul.cc create mode 100644 onnxruntime/contrib_ops/cuda/collective/distributed_matmul.h create mode 100644 onnxruntime/contrib_ops/cuda/collective/sharding.cc create mode 100644 onnxruntime/contrib_ops/cuda/collective/sharding.h create mode 100644 onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc create mode 100644 onnxruntime/contrib_ops/cuda/collective/sharding_spec.h create mode 100644 onnxruntime/test/python/onnxruntime_test_distributed.py diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 96c05e5282bb5..62c452b14e696 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -404,6 +404,9 @@ if (onnxruntime_USE_CUDA) if (NOT onnxruntime_USE_NCCL) list(REMOVE_ITEM onnxruntime_cuda_contrib_ops_cc_srcs "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/nccl_kernels.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/sharding_spec.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/sharding.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_matmul.cc" ) endif() # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio @@ -452,7 +455,6 @@ if (onnxruntime_USE_CUDA) "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/nccl_kernels.cc" "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/megatron.cc" ) - list(REMOVE_ITEM onnxruntime_providers_cuda_src ${onnxruntime_cuda_nccl_op_srcs}) endif() endif() diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index cf71b6bcf7c7d..55d03c14270d3 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -109,7 +109,12 @@ if (NOT onnxruntime_ENABLE_ATEN) list(APPEND contrib_ops_excluded_files "aten_ops/aten_op.cc") endif() if (NOT onnxruntime_USE_NCCL) + # Those are string patterns to exclude. Do NOT use stars such as + # collective/*.cc or *.h. list(APPEND contrib_ops_excluded_files "collective/nccl_kernels.cc") + list(APPEND contrib_ops_excluded_files "collective/sharding.cc") + list(APPEND contrib_ops_excluded_files "collective/sharding_spec.cc") + list(APPEND contrib_ops_excluded_files "collective/distributed_matmul.cc") endif() set(provider_excluded_files diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_matmul.cc b/onnxruntime/contrib_ops/cuda/collective/distributed_matmul.cc new file mode 100644 index 0000000000000..253a58bd82a20 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_matmul.cc @@ -0,0 +1,306 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Distributed computation. +#include "sharding.h" +#include "distributed_matmul.h" +#include "nccl_kernels.h" +#include "mpi_include.h" + +// ORT system. +#include "core/providers/cpu/tensor/slice.h" +#include "core/providers/cuda/tensor/slice.h" +#include "core/providers/cuda/math/matmul.h" +#include "core/providers/cuda/tensor/transpose.h" +#include "core/providers/cuda/cuda_check_memory.h" + +// std C++. +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +static TensorShape InferMatmulOutputShape( + const TensorShape& shape_A, + const TensorShape& shape_B) { + // left_shape: [M, K] + // right_shape: [K, N] + // output_shape: [M, N] + ORT_ENFORCE( + shape_A.NumDimensions() >= 2 && shape_B.NumDimensions() >= 2, + "1-D tensor is not supported by this MatMul."); + ORT_ENFORCE( + shape_A.NumDimensions() == shape_B.NumDimensions(), + "A and B must have the same rank after shape broadcasting."); + size_t rank = shape_A.NumDimensions(); + std::vector shape_Y(rank, 0); + for (size_t i = 0; i < rank; ++i) { + const int64_t dim_A = shape_A[i]; + const int64_t dim_B = shape_B[i]; + if (i == rank - 1) { + shape_Y[i] = dim_B; + } else if (i == rank - 2) { + shape_Y[i] = dim_A; + } else if (dim_A == 1 && dim_B >= 1) { + // dim_A is 1. + // dim_B can be either 1 or other positive integer. + // due ot shape broadcast. + shape_Y[i] = dim_B; + } else if (dim_B == 1 && dim_A >= 1) { + // dim_B is 1. + // dim_A can be either 1 or other positive integer. + // due ot shape broadcast. + shape_Y[i] = dim_A; + } else { + ORT_ENFORCE(dim_A == dim_B, "Broadcasting can only happen when one of dim_A and dim_B is 1."); + shape_Y[i] = dim_A; + } + } + return TensorShape(shape_Y); +}; + +template +DistributedMatMul::DistributedMatMul(const OpKernelInfo& info) : NcclKernel(info) { + std::vector device_mesh_elements = info.GetAttrsOrDefault("device_mesh_elements"); + std::vector device_mesh_shape = info.GetAttrsOrDefault("device_mesh_shape"); + std::vector input_shard_specs = info.GetAttrsOrDefault("input_shard_specs"); + std::vector output_shard_specs = info.GetAttrsOrDefault("output_shard_specs"); + + for (size_t i = 0; i < input_shard_specs.size(); ++i) { + auto spec = CreateTensorPartitionSpec(input_shard_specs[i], device_mesh_shape, device_mesh_elements); + input_shard_specs_.push_back(spec); + } + for (size_t i = 0; i < output_shard_specs.size(); ++i) { + auto spec = CreateTensorPartitionSpec(output_shard_specs[i], device_mesh_shape, device_mesh_elements); + output_shard_specs_.push_back(spec); + } +} + +template +Status DistributedMatMul::ComputeInternal(OpKernelContext* context) const { + const auto tensor_shard_A = context->Input(0); + const auto tensor_shard_B = context->Input(1); + const auto& tensor_shard_shape_A = tensor_shard_A->Shape(); + const auto& tensor_shard_shape_B = tensor_shard_B->Shape(); + + auto rank_A = tensor_shard_shape_A.NumDimensions(); + auto rank_B = tensor_shard_shape_B.NumDimensions(); + // TODO(wechi): Fix MatMul(1-D, *) and MatMul(*, 1-D) cases. + ORT_ENFORCE(rank_A >= 2 && rank_B >= 2, "Broadcast rule for 1-D tensor is different than other cases."); + + const TensorPartitionSpec& spec_A = input_shard_specs_[0]; + const TensorPartitionSpec& spec_B = input_shard_specs_[1]; + const TensorPartitionSpec& spec_Y = output_shard_specs_[0]; + + const auto tensor_shape_A = ComputeOriginShape(tensor_shard_shape_A, spec_A); + const auto tensor_shape_B = ComputeOriginShape(tensor_shard_shape_B, spec_B); + + TensorShape normalized_shape_A; + TensorShape normalized_shape_B; + std::tie(normalized_shape_A, normalized_shape_B) = NormalizeShapes(tensor_shape_A, tensor_shape_B); + + TensorPartitionSpec normalized_spec_A; + TensorPartitionSpec normalized_spec_B; + std::tie(normalized_spec_A, normalized_spec_B) = NormalizeTensorPartitionSpecs(spec_A, spec_B); + + const auto tensor_shape_Y = InferMatmulOutputShape(normalized_shape_A, normalized_shape_B); + const auto tensor_shard_shape_Y = ComputeShardShape(tensor_shape_Y, spec_Y); + + // Case 1: A is not sharded, B is sharded. + // 1. shard on -1: MatMul(RR, RS) -> RS + // 2. shard on -2: MatMul(RR, SR) -> MatMul(RS, SR) + AllReduce -> RR + // 3. shard on other axis + if (normalized_spec_A.HasNoShard() && normalized_spec_B.HasShard()) { + if (normalized_spec_B.OnlyShardAxis(-1)) { + // Case 1-1 + // MatMul(RR, RS) -> RS + ORT_ENFORCE(spec_Y.OnlyShardAxis(-1), "Not supported yet."); + auto tensor_shard_Y = context->Output(0, tensor_shard_shape_Y); + ORT_ENFORCE(onnxruntime::cuda::FuncMatMul( + this, context, tensor_shard_A, tensor_shard_B, 1.0, false, false, false, false, tensor_shard_Y) == Status::OK()); + } else if (normalized_spec_B.OnlyShardAxis(-2)) { + // Case 1-2 + // MatMul(RR, SR) -> MatMul(RS, SR) + AllReduce -> RR + auto tmp_spec_A = CreateTensorShardSpec(spec_A.device_mesh, 0, -1, rank_A); + auto tmp_tensor_shard_A = ReshardTensor(this, context, spec_A, tmp_spec_A, nccl_->Rank(), tensor_shard_A); + + auto tensor_shard_Y = context->Output(0, tensor_shard_shape_Y); + ORT_ENFORCE(onnxruntime::cuda::FuncMatMul( + this, context, tmp_tensor_shard_A.get(), tensor_shard_B, 1.0, false, false, false, false, tensor_shard_Y) == Status::OK()); + ORT_ENFORCE(FuncAllReduce( + nccl_->Comm(), Stream(context), tensor_shard_Y, tensor_shard_Y) == Status::OK()); + } else { + // Case 1-3 + ORT_THROW("Not supported yet."); + } + } + + // Case 2: A is sharded, B is not sharded. + // 1. shard on -1: MatMul(RS, RR) -> MatMul(RS, SR) -> MatMul(RS, SR) + AllReduce -> RR + // 2. shard on -2: MatMul(SR, RR) -> SR + // 3. shard on other axis: : MatMul(SRR, RRR) -> MatMul(SRR, SRR) -> SRR + if (spec_A.HasShard() && spec_B.HasNoShard()) { + if (spec_A.OnlyShardAxis(-1) && spec_Y.HasNoShard()) { + // Case 2-1 + // Y is not really sharded in this case. + auto tensor_shard_Y = context->Output(0, tensor_shard_shape_Y); + + // TODO: Support cases with multi-dimension device mesh. + TensorPartitionSpec new_spec_B = CreateTensorShardSpec(spec_B.device_mesh, 0, -2, rank_B); + auto tensor_reshard_B = ShardTensor(this, context, new_spec_B, nccl_->Rank(), tensor_shard_B); + + ORT_ENFORCE(onnxruntime::cuda::FuncMatMul( + this, context, tensor_shard_A, tensor_reshard_B.get(), 1.0, false, false, false, false, tensor_shard_Y) == Status::OK()); + + ORT_ENFORCE(FuncAllReduce( + nccl_->Comm(), Stream(context), tensor_shard_Y, tensor_shard_Y) == Status::OK()); + return Status::OK(); + } else if (spec_A.OnlyShardAxis(-2) && spec_Y.OnlyShardAxis(-2)) { + // Case 2-2 + auto tensor_shard_Y = context->Output(0, tensor_shard_shape_Y); + ORT_ENFORCE(onnxruntime::cuda::FuncMatMul( + this, context, tensor_shard_A, tensor_shard_B, 1.0, false, false, false, false, tensor_shard_Y) == Status::OK()); + return Status::OK(); + } else if (spec_A.GetPartitionAxis() < gsl::narrow(tensor_shape_A.NumDimensions()) - 2 && normalized_spec_A.GetPartitionAxis() == spec_Y.GetPartitionAxis()) { + // Case 2-3 + if (normalized_shape_B[normalized_spec_A.GetPartitionAxis()] == 1) { + // Case 2-3-1. + // B is broadcasted to along sharding axis in A. + // E.g., MatMul(A(SRR), B(RR)) where normalized_shape_A = [2, 3, 4] and normalized_shape_B = [1, 4, 3]. + // No resharding is required. + auto tensor_shard_Y = context->Output(0, tensor_shard_shape_Y); + ORT_ENFORCE(onnxruntime::cuda::FuncMatMul( + this, context, tensor_shard_A, tensor_shard_B, 1.0, false, false, false, false, tensor_shard_Y) == Status::OK()); + return Status::OK(); + } else { + // Case 2-3-2. + // No broadcasting + // Allocate tensor based on shape sharded non-matrix axis. + // MatMul(SRR, RRR) -> MatMul(SRR, SRR) -> SRR + auto tensor_shard_Y = context->Output(0, tensor_shard_shape_Y); + TensorPartitionSpec new_spec_B = CreateTensorShardSpec( + spec_B.device_mesh, + 0, + spec_A.GetNegativePartitionAxis(), + rank_B); + auto tensor_reshard_B = ShardTensor(this, context, new_spec_B, nccl_->Rank(), tensor_shard_B); + ORT_ENFORCE(onnxruntime::cuda::FuncMatMul( + this, context, tensor_shard_A, tensor_reshard_B.get(), 1.0, false, false, false, false, tensor_shard_Y) == Status::OK()); + return Status::OK(); + } + } else { + ORT_THROW("Not supported yet."); + } + } + + // Case 3: A is sharded, B is sharded. + // 1. shard on (-1, -1): MatMul(RS, RS) -> MatMul(RS, SR) + AllReduce -> RR + // -> MatMul(RR, RS) -> RS + // 2. shard on (-1, -2): MatMul(RS, SR) -> MatMul(RS, SR) + AllReduce -> RR + // 3. shard on (-2, -1): MatMul(SR, RS) -> MatMul(RS, SR) + AllReduce -> RR + // 4. shard on (-2, -2): MatMul(SR, SR) -> MatMul(RS, SR) + AllReduce -> RR + // 5. shard on other axes + if (spec_A.HasShard() && spec_B.HasShard()) { + if (spec_A.OnlyShardAxis(-1) && spec_B.OnlyShardAxis(-1)) { + // Case 3-1 + if (spec_Y.HasNoShard()) { + // Case 3-1-1 + auto tmp_spec_B = CreateTensorShardSpec(spec_B.device_mesh, 0, -2, rank_B); + auto tmp_tensor_shard_B = ReshardTensor(this, context, spec_B, tmp_spec_B, nccl_->Rank(), tensor_shard_B); + auto tensor_shard_Y = context->Output(0, tensor_shard_shape_Y); + ORT_ENFORCE(onnxruntime::cuda::FuncMatMul( + this, context, tensor_shard_A, tmp_tensor_shard_B.get(), 1.0, false, false, false, false, tensor_shard_Y) == Status::OK()); + ORT_ENFORCE(FuncAllReduce( + nccl_->Comm(), Stream(context), tensor_shard_Y, tensor_shard_Y) == Status::OK()); + } else if (spec_Y.OnlyShardAxis(-1)) { + // Cas 3-1-2 + auto tmp_spec_A = TensorPartitionSpec::CreateAllReplica(spec_A); + auto tmp_tensor_shard_A = ReshardTensor(this, context, spec_A, tmp_spec_A, nccl_->Rank(), tensor_shard_A); + auto tensor_shard_Y = context->Output(0, tensor_shard_shape_Y); + ORT_ENFORCE(onnxruntime::cuda::FuncMatMul( + this, context, tmp_tensor_shard_A.get(), tensor_shard_B, 1.0, false, false, false, false, tensor_shard_Y) == Status::OK()); + } else { + ORT_THROW("Not supported yet."); + } + } else if (spec_A.OnlyShardAxis(-1) && spec_B.OnlyShardAxis(-2) && spec_Y.HasNoShard()) { + // Case 3-2 + auto tensor_shard_Y = context->Output(0, tensor_shard_shape_Y); + + auto status = onnxruntime::cuda::FuncMatMul( + this, context, tensor_shard_A, tensor_shard_B, 1.0, false, false, false, false, tensor_shard_Y); + ORT_ENFORCE(status == Status::OK(), status.ErrorMessage()); + + status = FuncAllReduce( + nccl_->Comm(), Stream(context), tensor_shard_Y, tensor_shard_Y); + ORT_ENFORCE(status == Status::OK(), status.ErrorMessage()); + } else if (spec_A.OnlyShardAxis(-2) && spec_B.OnlyShardAxis(-1)) { + // Case 3-3: + // MatMul(SR, RS) -> MatMul(RS, SR) + AllReduce -> RR + ORT_ENFORCE(spec_Y.HasNoShard(), "Not supported yet."); + + // A[RS] + auto tmp_spec_A = CreateTensorShardSpec(spec_A.device_mesh, 0, -1, rank_A); + auto tmp_tensor_shard_A = ReshardTensor(this, context, spec_A, tmp_spec_A, nccl_->Rank(), tensor_shard_A); + + // B[SR] + auto tmp_spec_B = CreateTensorShardSpec(spec_B.device_mesh, 0, -2, rank_B); + auto tmp_tensor_shard_B = ReshardTensor(this, context, spec_B, tmp_spec_B, nccl_->Rank(), tensor_shard_B); + + // Allocate Y[RR] + auto tensor_shard_Y = context->Output(0, tensor_shard_shape_Y); + + // Run local MatMul(A[RS], B[SR]) + ORT_ENFORCE(onnxruntime::cuda::FuncMatMul( + this, context, tmp_tensor_shard_A.get(), tmp_tensor_shard_B.get(), 1.0, false, false, false, false, tensor_shard_Y) == Status::OK()); + ORT_ENFORCE(FuncAllReduce( + nccl_->Comm(), Stream(context), tensor_shard_Y, tensor_shard_Y) == Status::OK()); + } else if (spec_A.OnlyShardAxis(-2) && spec_B.OnlyShardAxis(-2)) { + // Case 3-4 + // MatMul(SR, SR) -> MatMul(RS, SR) + AllReduce -> RR + ORT_ENFORCE(spec_Y.HasNoShard(), "Not supported yet."); + auto tmp_spec_A = CreateTensorShardSpec(spec_A.device_mesh, 0, -1, rank_A); + auto tmp_tensor_shard_A = ReshardTensor(this, context, spec_A, tmp_spec_A, nccl_->Rank(), tensor_shard_A); + auto tensor_sard_Y = context->Output(0, tensor_shard_shape_Y); + ORT_ENFORCE(onnxruntime::cuda::FuncMatMul( + this, context, tmp_tensor_shard_A.get(), tensor_shard_B, 1.0, false, false, false, false, tensor_sard_Y) == Status::OK()); + } else { + // Case 3-5 + ORT_THROW("Not supported yet."); + } + } + + // Case 4: A is not sharded, B is not sharded. + // - Easy! + return Status::OK(); +} + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedMatMul, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .AllocateInputsContiguously() + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + DistributedMatMul); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedMatMul, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .AllocateInputsContiguously() + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + DistributedMatMul); + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_matmul.h b/onnxruntime/contrib_ops/cuda/collective/distributed_matmul.h new file mode 100644 index 0000000000000..d8df24c03498f --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_matmul.h @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "sharding_spec.h" +#include "core/providers/cuda/cuda_kernel.h" + +#include +#include +#include +#include +#include +#include + +#pragma once + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +class DistributedMatMul final : public NcclKernel { + public: + explicit DistributedMatMul(const OpKernelInfo& info); + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + std::vector input_shard_specs_; + std::vector output_shard_specs_; +}; + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc b/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc index bb924a0d49cfe..ff49d7174c329 100644 --- a/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc @@ -3,6 +3,9 @@ #include "nccl_kernels.h" #include "mpi_include.h" +#include "core/providers/cpu/tensor/slice.h" +#include "core/providers/cuda/tensor/slice.h" +#include "core/providers/cuda/math/matmul.h" #include "core/providers/cuda/tensor/transpose.h" #include "core/providers/cuda/cuda_check_memory.h" @@ -246,6 +249,116 @@ ONNX_OPERATOR_KERNEL_EX( .TypeConstraint("T", DataTypeImpl::AllTensorTypes()), AllToAll); +Status FuncAllReduce( + ncclComm_t comm, + cudaStream_t stream, + const Tensor* input, + Tensor* output) { + const void* input_data = input->DataRaw(); + const auto input_shape = input->Shape(); + int64_t input_count = input_shape.Size(); + + void* output_data = output->MutableDataRaw(); + + ncclDataType_t dtype = GetNcclDataType(input->DataType()); + NCCL_RETURN_IF_ERROR(ncclAllReduce(input_data, output_data, input_count, dtype, ncclSum, comm, stream)); + return Status::OK(); +} + +static std::vector CalculatePermToSwapAxes( + const int64_t axis, + const int64_t another_axis, + const size_t rank) { + // This is for swapping axis and another_axis. + // NCCL's AllGather only gathers along axis 0. If gathering along another axis is needed, + // we need to call transpose. E.g., + // Case 1: + // AllGather(axis=0) + // Case 2: + // AllGather(axis=3) = Transpose(perm=[3, 1, 2, 0]) -> AllGather(axis=0) -> Transpose(perm=[3, 1, 2, 0]) + std::vector permutation(rank); + std::iota(std::begin(permutation), std::end(permutation), 0); + permutation[axis] = another_axis; + permutation[another_axis] = axis; + return permutation; +} + +void FuncAllGather( + const NcclKernel* nccl_kernel, + OpKernelContext* ctx, + const Tensor* input, + const int64_t group_size, + const int64_t axis, + Tensor* output) { + ORT_ENFORCE(output->Shape().Size() == input->Shape().Size() * group_size, "Input and output shapes mismatch."); + ORT_ENFORCE(group_size >= 0, "group_size should be non-negative."); + ORT_ENFORCE(axis >= 0, "axis should be non-negative."); + AllocatorPtr alloc; + ORT_ENFORCE(ctx->GetTempSpaceAllocator(&alloc) == Status::OK(), "Fail to find allocator."); + if (axis == 0) { + const void* input_data = input->DataRaw(); + const auto input_shape = input->Shape(); + void* output_data = output->MutableDataRaw(); + ncclAllGather( + input_data, + output_data, + input_shape.Size(), + GetNcclDataType(input->DataType()), + nccl_kernel->Comm(), + nccl_kernel->Stream(ctx)); + } else { + const auto source_shape = input->Shape(); + TensorShape transposed_shape(source_shape); + transposed_shape[0] = source_shape[axis]; + transposed_shape[axis] = source_shape[0]; + + auto transposed_buffer = Tensor::Create(input->DataType(), transposed_shape, alloc); + + // swap axis 0 and axis axis + std::vector perm = CalculatePermToSwapAxes(0, axis, source_shape.NumDimensions()); + + ORT_ENFORCE(onnxruntime::cuda::Transpose::DoTranspose(nccl_kernel->GetDeviceProp(), + nccl_kernel->Stream(ctx), + nccl_kernel->GetCublasHandle(ctx), + perm, *input, *transposed_buffer) == Status::OK()); + + TensorShape gathered_shape(transposed_shape); + gathered_shape[0] = group_size * transposed_shape[0]; + auto gathered_buffer = Tensor::Create(input->DataType(), gathered_shape, alloc); + + ncclAllGather( + transposed_buffer->DataRaw(), + gathered_buffer->MutableDataRaw(), + transposed_shape.Size(), + GetNcclDataType(input->DataType()), + nccl_kernel->Comm(), + nccl_kernel->Stream(ctx)); + + ORT_ENFORCE(gathered_buffer->Shape().Size() == output->Shape().Size()); + ORT_ENFORCE(onnxruntime::cuda::Transpose::DoTranspose(nccl_kernel->GetDeviceProp(), + nccl_kernel->Stream(ctx), + nccl_kernel->GetCublasHandle(ctx), + perm, *gathered_buffer, *output) == Status::OK()); + } +} + +std::unique_ptr FuncAllGather( + const NcclKernel* nccl_kernel, + OpKernelContext* ctx, + const Tensor* input, + const int64_t group_size, + const int64_t axis) { + ORT_ENFORCE(group_size >= 0); + ORT_ENFORCE(axis >= 0); + AllocatorPtr alloc; + ORT_ENFORCE(ctx->GetTempSpaceAllocator(&alloc) == Status::OK()); + TensorShape output_shape(input->Shape()); + output_shape[axis] = group_size * output_shape[axis]; + auto output = Tensor::Create(input->DataType(), output_shape, alloc); + FuncAllGather(nccl_kernel, ctx, input, group_size, axis, output.get()); + return output; +} + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h b/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h index 24df69ea50224..7fc26e6be57b9 100644 --- a/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h +++ b/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h @@ -6,7 +6,12 @@ #include "core/providers/cuda/cuda_kernel.h" #if defined(ORT_USE_NCCL) +#include +#include +#include +#include #include +#include #endif namespace onnxruntime { @@ -44,6 +49,10 @@ class NcclKernel : public ::onnxruntime::cuda::CudaKernel { public: explicit NcclKernel(const OpKernelInfo& info); + ncclComm_t Comm() const { + return nccl_->Comm(); + } + protected: NcclContext* nccl_ = nullptr; }; @@ -81,6 +90,27 @@ class AllToAll final : public NcclKernel { int64_t group_size_ = -1; }; +Status FuncAllReduce( + ncclComm_t comm, + cudaStream_t stream, + const Tensor* input, + Tensor* output); + +void FuncAllGather( + const NcclKernel* nccl_kernel, + OpKernelContext* ctx, + const Tensor* input, + const int64_t group_size, + const int64_t axis, + Tensor* output); + +std::unique_ptr FuncAllGather( + const NcclKernel* nccl_kernel, + OpKernelContext* ctx, + const Tensor* input, + const int64_t group_size, + const int64_t axis); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding.cc b/onnxruntime/contrib_ops/cuda/collective/sharding.cc new file mode 100644 index 0000000000000..d9f2f3c1bcbca --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/sharding.cc @@ -0,0 +1,219 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "sharding.h" +#include "mpi_include.h" +#include "sharding_spec.h" + +#include "core/providers/cpu/tensor/slice.h" +#include "core/providers/cuda/tensor/slice.h" +#include "core/providers/cuda/math/matmul.h" +#include "core/providers/cuda/tensor/transpose.h" +#include "core/providers/cuda/cuda_check_memory.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +void GatherTensor( + // Use OpKernel and do a pointer cast to unify functional calls with other eps. + // TODO: remove CudaKernel and OpKernelContext. + const NcclKernel* nccl_kernel, + // Do NOT use ctx to access inputs and outputs. + // Inputs and outputs are passed in as function arguments. + OpKernelContext* ctx, + const TensorPartitionSpec& spec, + const Tensor* tensor, + Tensor* gathered) { + const int64_t shard_axis = spec.GetPartitionAxis(); + const int64_t shard_count = spec.GetPartitionCount(shard_axis); + + FuncAllGather( + nccl_kernel, + ctx, + tensor, + shard_count, + shard_axis, + gathered); +} + +std::unique_ptr GatherTensor( + // Use OpKernel and do a pointer cast to unify functional calls with other eps. + // TODO: remove CudaKernel and OpKernelContext. + const NcclKernel* nccl_kernel, + // Do NOT use ctx to access inputs and outputs. + // Inputs and outputs are passed in as function arguments. + OpKernelContext* ctx, + const TensorPartitionSpec& spec, + const Tensor* tensor) { + const int64_t shard_axis = spec.GetPartitionAxis(); + const int64_t shard_count = spec.GetPartitionCount(shard_axis); + TensorShape gathered_shape(tensor->Shape()); + gathered_shape[shard_axis] *= shard_count; + + AllocatorPtr alloc; + ORT_ENFORCE(ctx->GetTempSpaceAllocator(&alloc) == Status::OK()); + auto gathered = Tensor::Create(tensor->DataType(), gathered_shape, alloc); + + FuncAllGather( + nccl_kernel, + ctx, + tensor, + shard_count, + shard_axis, + gathered.get()); + + return gathered; +} + +void ShardTensor( + // Use OpKernel and do a pointer cast to unify functional calls with other eps. + // TODO: remove CudaKernel and OpKernelContext. + const NcclKernel* nccl_kernel, + // Do NOT use ctx to access inputs and outputs. + // Inputs and outputs are passed in as function arguments. + OpKernelContext* ctx, + const TensorPartitionSpec& spec, + const int64_t device_id, + const Tensor* tensor, + Tensor* shard_tensor) { + const int64_t shard_axis = spec.GetPartitionAxis(); + const int64_t shard_count = spec.GetPartitionCount(shard_axis); + TensorShape shard_shape = ComputeShardShape( + tensor->Shape(), + shard_axis, + shard_count); + const int64_t shard_dim = shard_shape[shard_axis]; + const std::vector starts = {shard_dim * device_id}; + const std::vector ends = {shard_dim * (device_id + 1)}; + const std::vector axes = {shard_axis}; + const std::vector steps = {1}; + + ORT_ENFORCE(FuncSlice( + nccl_kernel, + ctx, + tensor, + starts, + ends, + axes, + steps, + shard_tensor) == Status::OK()); +} + +std::unique_ptr ShardTensor( + const NcclKernel* nccl_kernel, + OpKernelContext* ctx, + const TensorPartitionSpec& spec, + const int64_t device_id, + const Tensor* tensor) { + // Shard all-replica tensor per spec. + + AllocatorPtr alloc; + ORT_ENFORCE(ctx->GetTempSpaceAllocator(&alloc) == Status::OK()); + + TensorShape shard_shape = ComputeShardShape( + tensor->Shape(), + spec.GetPartitionAxis(), + spec.GetPartitionCount(spec.GetPartitionAxis())); + auto shard_buffer = Tensor::Create(tensor->DataType(), shard_shape, alloc); + + // Shard with pre-allocated buffer. + ShardTensor( + nccl_kernel, + ctx, + spec, + device_id, + tensor, + shard_buffer.get()); + + return shard_buffer; +} + +void ReshardTensor( + // Use OpKernel and do a pointer cast to unify functional calls with other eps. + // TODO: remove CudaKernel and OpKernelContext. + const NcclKernel* nccl_kernel, + // Do NOT use ctx to access inputs and outputs. + // Inputs and outputs are passed in as function arguments. + OpKernelContext* ctx, + const TensorPartitionSpec& src_spec, + const TensorPartitionSpec& dst_spec, + const int64_t device_id, + const Tensor* src, + Tensor* dst) { + if (src_spec.HasShard() && dst_spec.HasNoShard()) { + GatherTensor( + nccl_kernel, + ctx, + src_spec, + src, + dst); + return; + } else if (src_spec.HasNoShard() && dst_spec.HasShard()) { + ShardTensor( + nccl_kernel, + ctx, + dst_spec, + device_id, + src, + dst); + } else if (src_spec.HasShard() && dst_spec.HasShard()) { + int64_t src_axis = src_spec.GetPartitionAxis(); + int64_t dst_axis = dst_spec.GetPartitionAxis(); + ORT_ENFORCE(src_axis != dst_axis, "No reshard is needed. Don't call this function."); + + auto all_replica_buffer = GatherTensor( + nccl_kernel, + ctx, + src_spec, + src); + + ShardTensor( + nccl_kernel, + ctx, + dst_spec, + device_id, + all_replica_buffer.get(), + dst); + } else { + ORT_THROW("Not supported yet. Probably resharding is not needed."); + } +} + +std::unique_ptr ReshardTensor( + // Use OpKernel and do a pointer cast to unify functional calls with other eps. + // TODO: remove CudaKernel and OpKernelContext. + const NcclKernel* nccl_kernel, + // Do NOT use ctx to access inputs and outputs. + // Inputs and outputs are passed in as function arguments. + OpKernelContext* ctx, + const TensorPartitionSpec& src_spec, + const TensorPartitionSpec& dst_spec, + const int64_t device_id, + const Tensor* src) { + // Implement ReshardTensor but returning a unique_ptr to Tensor instead. + const auto origin_shape = ComputeOriginShape(src->Shape(), src_spec); + const auto dst_shape = ComputeShardShape(origin_shape, dst_spec); + ORT_ENFORCE(CanShard(origin_shape, dst_spec), "Cannot shard tensor. Shape:", origin_shape, ", sharding spec: ", dst_spec.ToString()); + + AllocatorPtr alloc; + ORT_ENFORCE(ctx->GetTempSpaceAllocator(&alloc) == Status::OK()); + auto dst = Tensor::Create(src->DataType(), dst_shape, alloc); + ReshardTensor( + nccl_kernel, + ctx, + src_spec, + dst_spec, + device_id, + src, + dst.get()); + return dst; +} + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding.h b/onnxruntime/contrib_ops/cuda/collective/sharding.h new file mode 100644 index 0000000000000..497826160aaab --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/sharding.h @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "sharding_spec.h" +#include "nccl_kernels.h" + +#pragma once + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +void GatherTensor( + const NcclKernel* nccl_kernel, + OpKernelContext* ctx, + const TensorPartitionSpec& spec, + const Tensor* tensor, + Tensor* gathered); + +std::unique_ptr GatherTensor( + const NcclKernel* nccl_kernel, + OpKernelContext* ctx, + const TensorPartitionSpec& spec, + const Tensor* tensor); + +void ShardTensor( + const NcclKernel* nccl_kernel, + OpKernelContext* ctx, + const TensorPartitionSpec& spec, + const int64_t device_id, + const Tensor* tensor, + Tensor* shard_tensor); + +std::unique_ptr ShardTensor( + const NcclKernel* nccl_kernel, + OpKernelContext* ctx, + const TensorPartitionSpec& spec, + const int64_t device_id, + const Tensor* tensor); + +void ReshardTensor( + const NcclKernel* nccl_kernel, + OpKernelContext* ctx, + const TensorPartitionSpec& src_spec, + const TensorPartitionSpec& dst_spec, + const int64_t device_id, + const Tensor* src, + Tensor* dst); + +std::unique_ptr ReshardTensor( + const NcclKernel* nccl_kernel, + OpKernelContext* ctx, + const TensorPartitionSpec& src_spec, + const TensorPartitionSpec& dst_spec, + const int64_t device_id, + const Tensor* src); + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc new file mode 100644 index 0000000000000..f1d399077e37b --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc @@ -0,0 +1,193 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "sharding_spec.h" + +#include "core/common/common.h" +#include "core/common/gsl.h" +#include "core/framework/tensor_shape.h" + +#include +#include +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +void ValidateAxisIndex(const int64_t axis, const int64_t rank) { + int64_t adjusted_axis = axis; + if (axis < 0) { + adjusted_axis = axis + rank; + } else { + adjusted_axis = axis; + } + ORT_ENFORCE(adjusted_axis >= 0 && adjusted_axis < rank, "axis,", axis, ", should be in [", -rank, ",", rank, ")."); +} + +DeviceMesh CreateDeviceMesh( + std::vector device_mesh_shape, + std::vector device_mesh_elements) { + DeviceMesh device_mesh; + device_mesh.device_mesh_shape = device_mesh_shape; + device_mesh.device_mesh_elements = device_mesh_elements; + return device_mesh; +} + +TensorPartitionSpec CreateTensorPartitionSpec(std::string spec_string, std::vector device_mesh_shape, std::vector device_mesh_elements) { + // "S[0]R" + std::vector axis_specs; + size_t dim_index = 0; + size_t token_index = 0; + while (token_index < spec_string.size()) { + char token = spec_string.at(token_index); + if (token == 'R') { + AxisPartitionSpec axis_spec = AxisPartitionSpec::CreateReplica(); + axis_specs.push_back(axis_spec); + ++token_index; + ++dim_index; + } else if (token == 'S') { + std::stringstream ss; + // Next should be "[". + ++token_index; + char left_bracket = spec_string.at(token_index); + ORT_ENFORCE(left_bracket == '[', "Invalid partition token: ", left_bracket, " in ", spec_string); + // Move to digit part. + ++token_index; + while (spec_string.at(token_index) != ']') { + // Now token_index should points to the first digit of + // axis index. + char digit = spec_string.at(token_index); + ORT_ENFORCE(std::isdigit(digit), "Invalid partition token: ", token, " in ", spec_string); + ss << digit; + // Loaded a digit. Go to next token. + ++token_index; + } + int device_mesh_index = 0; + ss >> device_mesh_index; + AxisPartitionSpec axis_spec = AxisPartitionSpec::CreateShard(device_mesh_index); + axis_specs.push_back(axis_spec); + // Skip "]". + char right_bracket = spec_string.at(token_index); + ORT_ENFORCE(right_bracket == ']', "Invalid partition token: ", token, " in ", spec_string); + ++token_index; + } else { + throw std::invalid_argument("Invalid partition token: " + token); + } + } + DeviceMesh device_mesh = CreateDeviceMesh(device_mesh_shape, device_mesh_elements); + return TensorPartitionSpec::Create(axis_specs, device_mesh); +} + +TensorPartitionSpec CreateTensorShardSpec( + const DeviceMesh& device_mesh, + int64_t device_mesh_axis, + int64_t shard_axis, + int64_t tensor_rank) { + if (shard_axis < 0) { + shard_axis += tensor_rank; + } + std::vector axis_specs; + for (int64_t i = 0; i < tensor_rank; ++i) { + if (i == shard_axis) { + axis_specs.push_back(AxisPartitionSpec::CreateShard(device_mesh_axis)); + } else { + axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + } + } + return TensorPartitionSpec::Create(axis_specs, device_mesh); +} + +TensorShape ComputeOriginShape(const TensorShape& shard_shape, const TensorPartitionSpec& spec) { + ORT_ENFORCE(gsl::narrow(shard_shape.NumDimensions()) == spec.Rank(), "Shard shape and spec rank mismatch."); + if (spec.HasNoShard()) { + return shard_shape; + } + TensorShape shape(shard_shape); + const int64_t axis = spec.GetPartitionAxis(); + shape[axis] *= spec.GetPartitionCount(axis); + return shape; +} + +TensorShape ComputeShardShape(const TensorShape& shape, const TensorPartitionSpec& spec) { + ORT_ENFORCE(gsl::narrow(shape.NumDimensions()) == spec.Rank(), "Shape and spec rank mismatch."); + TensorShape shard_shape(shape); + if (spec.HasNoShard()) { + return shard_shape; + } + const int64_t axis = spec.GetPartitionAxis(); + shard_shape[axis] /= spec.GetPartitionCount(axis); + return shard_shape; +} + +TensorShape ComputeShardShape(const TensorShape source_shape, int64_t shard_axis, int64_t shard_count) { + if (shard_axis < 0) { + shard_axis += gsl::narrow(source_shape.NumDimensions()); + } + TensorShape shard_shape(source_shape); + ORT_ENFORCE(shard_axis < gsl::narrow(source_shape.NumDimensions()), "Shard axis must be less than the number of dimensions of the source tensor."); + ORT_ENFORCE(source_shape[shard_axis] % shard_count == 0, "Number of shards must be divisible by sharded axis' dimension."); + shard_shape[shard_axis] = source_shape[shard_axis] / shard_count; + return shard_shape; +} + +std::tuple NormalizeShapes(const TensorShape& left, const TensorShape& right) { + if (left.NumDimensions() > right.NumDimensions()) { + std::vector right_vector(right.NumDimensions(), 0); + right.CopyDims(right_vector.data(), right.NumDimensions()); + // Fill 1's to right shape. E.g., + // left: [1, 2, 3, 4], right: [5, 6, 7] -> left: [1, 2, 3, 4], right: [1, 5, 6, 7] + right_vector.insert(right_vector.begin(), left.NumDimensions() - right.NumDimensions(), 1); + return std::make_tuple(left, TensorShape(right_vector)); + } else if (left.NumDimensions() < right.NumDimensions()) { + std::vector left_vector(left.NumDimensions(), 0); + left.CopyDims(left_vector.data(), left.NumDimensions()); + // Fill 1's to left shape. E.g., + // left: [1, 2, 3], right: [4, 5, 6, 7] -> left: [1, 2, 3, 1], right: [4, 5, 6, 7] + left_vector.insert(left_vector.begin(), right.NumDimensions() - left.NumDimensions(), 1); + return std::make_tuple(TensorShape(left_vector), TensorShape(right)); + } else { + return std::make_tuple(TensorShape(left), TensorShape(right)); + } +} + +std::tuple NormalizeTensorPartitionSpecs( + const TensorPartitionSpec& left, const TensorPartitionSpec& right) { + // TODO: Make it to modify left and right instead of returning new values. + if (left.axis_specs.size() > right.axis_specs.size()) { + auto new_right = TensorPartitionSpec::Create(right.axis_specs, right.device_mesh); + new_right.axis_specs.insert(new_right.axis_specs.begin(), left.axis_specs.size() - right.axis_specs.size(), AxisPartitionSpec::CreateReplica()); + return std::make_tuple(left, new_right); + } else if (left.axis_specs.size() < right.axis_specs.size()) { + auto new_left = TensorPartitionSpec::Create(left.axis_specs, left.device_mesh); + new_left.axis_specs.insert(new_left.axis_specs.begin(), right.axis_specs.size() - left.axis_specs.size(), AxisPartitionSpec::CreateReplica()); + return std::make_tuple(new_left, right); + } else { + return std::make_tuple(left, right); + } +} + +bool CanShard(const TensorShape& shape, const TensorPartitionSpec& spec) { + if (spec.HasNoShard()) { + return true; + } + if (gsl::narrow(shape.NumDimensions()) != spec.Rank()) { + return false; + } + const int64_t axis = spec.GetPartitionAxis(); + if (axis < 0 || gsl::narrow(axis) >= shape.NumDimensions()) { + return false; + } + if (shape[axis] % spec.GetPartitionCount(axis) != 0) { + return false; + } + return true; +} + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h new file mode 100644 index 0000000000000..13982ee7711cf --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h @@ -0,0 +1,374 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/common.h" +#include "core/framework/tensor_shape.h" + +#include +#include +#include + +#pragma once + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +class DeviceMesh { + public: + // [Device Mesh and Tensor Sharding for Tensor Parallel] + // Device mesh is a tensor of device indices. + // A tensor can then be partitioned along specific mesh axes. + // + // Assume we have 4 GPUs indexed by 0, 1, 2, and 3. + // Let's consider some examples. + // 1. 1D device mesh [0, 1, 2, 3]. In this case, + // device_mesh_shape is [4] and device_mesh_elements + // is [0, 1, 2, 3]. + // If we want to shard a 2-D tensor along its axis 1, the + // corresponding sharding spec is a string "RS[0]". + // 2. 2D device mesh [[0, 1], [2, 3]]. In this case, + // device_mesh_shape is [2, 2] and device_mesh_elements + // is [0, 1, 2, 3]. + // If we want to shard a 2-D tensor's + // rows along mesh axis 1 and + // columns along mesh axis 0, the + // corresponding sharding spec is a string "S[1]S[0]". + // If that 2-D tensor's value is np.array([[5, 6], [7, 8]]), + // GPU 0/1/2/3 owns 5/7/6/8. Below is a visualization the sharding + // proccess. + // - Start with a 2-D device mesh [[0, 1], [2, 3]] and + // a 2-D tensor [[5, 6], [7, 8]] + // - GPU: [[0, 1], [2, 3]], Tensor: [[5, 6], [7, 8]] + // - Split GPU mesh along axis 1 and tensor along + // axis 0 for "S[1]" in "S[1]S[0]" + // - GPU: [[0], [2]], Tensor: [[5, 6]] + // GPU: [[1], [3]], Tensor: [[7, 8]] + // - Split GPU mesh along axis 0 and tensor along + // axis 1 for "S[0]" in "S[1]S[0]" + // - GPU: [[0]], Tensor: [[5]] + // - GPU: [[2]], Tensor: [[6]] + // - GPU: [[1]], Tensor: [[7]] + // - GPU: [[3]], Tensor: [[8]] + + // Actual shape of device mesh represented by `device_mesh_elements`. + std::vector device_mesh_shape; + + // Flattened device mesh. + std::vector device_mesh_elements; + + // Helper to debug and generate error message; e.g., + // "DeviceMesh{Shape: [2,2,], Elements: [0,1,2,3,]}". + std::string ToString() const { + std::ostringstream os; + os << "DeviceMesh{Shape: ["; + for (const auto& shape : device_mesh_shape) + os << shape << ","; + os << "], Elements: ["; + for (const auto& element : device_mesh_elements) + os << element << ","; + os << "]}"; + return os.str(); + } + + // Call this in GDB to visualize the mesh. + void Print() const { + std::cout << ToString() << std::endl; + } +}; + +class AxisPartitionSpec { + // [Device Mesh and Tensor Sharding for Tensor Parallel] + // This class is the in-memory representation of + // 1. if a tensor is sharded or not (aka replica), and + // 2. which tensor axis is shard by which device mesh axis. + // Let's consider sharding 2-D tensor along column axis on + // device mesh [0, 1] as an example. + // The required sharding spec RS[0] can be represented by + // - AxisPartitionSpec(Condition::Replica, -1) + // - AxisPartitionSpec(Condition::Shard, 0) + public: + // Status of a tensor axis. + // A tensor axis can be either sharded or replicated + // along a device mesh axis. + enum class Condition { Replica, + Shard }; + + // This field tells if a tensor axis is sharded or not. + Condition cond; + + // If a tensor axis is sharded, this field tells which device + // mesh axis to distribute the shards along. + // If a tensor axis is not sharded, this field is ignored. + int device_mesh_axis; + + // A helper to construct a replica spec for a tensor axis. + static AxisPartitionSpec CreateReplica() { + return AxisPartitionSpec(Condition::Replica, -1); + } + + // A helper to construct a sharding spec for a tensor axis. + // This tensor axis is sharded along `device_mesh_axis` in device mesh. + static AxisPartitionSpec CreateShard(int device_mesh_axis) { + return AxisPartitionSpec(Condition::Shard, device_mesh_axis); + } + + // A normal ctor. + // TODO(wechi): Consider to hide it and revise the `public` members/functions + // exposed to the user. + AxisPartitionSpec(Condition cond_, int device_mesh_axis_) : device_mesh_axis(device_mesh_axis_), cond(cond_) {} + + // Helper to debug and generate error message; e.g., + // "RS[0]". + std::string ToString() const { + std::ostringstream os; + os << (cond == Condition::Replica ? "R" : "S"); + if (cond == Condition::Shard) os << "[" << device_mesh_axis << "]"; + return os.str(); + } + + // Call this in GDB to visualize the spec. + void Print() const { + std::cout << ToString() << std::endl; + } +}; + +// Return true if `axis` is a valid axis index for a tensor of rank `rank`. +// Negative `axis` is allowed (e.g., -1 for the last axis). +void ValidateAxisIndex(const int64_t axis, const int64_t rank); + +class TensorPartitionSpec { + // [Device Mesh and Tensor Sharding for Tensor Parallel] + // TensorPartitionSpec holds a collection of AxisPartitionSpec and an + // associated DeviceMesh. It is responsible for determining how a tensor + // should be partitioned across a device mesh. + // + // Example 1: RS[0] + // In this scenario, `axis_specs` would contain two `AxisPartitionSpec` objects. + // - The first object is a Replica, denoting that the first axis of the tensor is + // not sharded but is instead replicated. + // - The second object is a Shard along the 0-th axis of the device mesh. It denotes + // that the second axis of the tensor is sharded along the first axis of the + // device mesh. + // + // Example 2: S[0]RR + // In this scenario, `axis_specs` would contain three `AxisPartitionSpec` objects. + // - The first object is a Shard along the 0-th axis of the device mesh, indicating + // that the first axis of the tensor is sharded along the first axis of the + // device mesh. + // - The second and third objects are Replicas, indicating that the second and third + // axes of the tensor are not sharded but are instead replicated. + public: + // axis_specs[i]: AxisPartitionSpec for tensor axis i. For a 2-D tensor, + // axis_specs[0] is for row axis and axis_specs[1] is for + // column axis. axis_specs[i].device_mesh_axis = j means that + // tensor axis i is sharded along device mesh axis j. + std::vector axis_specs; + + // device_mesh: DeviceMesh for sharding the associated tensor. + // Read [Device Mesh and Tensor Sharding for Tensor Parallel] in DeviceMesh's comment. + DeviceMesh device_mesh; + + // Replacement of ctor. + static TensorPartitionSpec Create( + const std::vector& axis_specs, const DeviceMesh& device_mesh) { + TensorPartitionSpec spec; + spec.axis_specs = axis_specs; + spec.device_mesh = device_mesh; + return spec; + } + + // Copy-construct `spec` but with all tensor axes replicated. + // The new spec have the same number of axis specs and the same device mesh. + static TensorPartitionSpec CreateAllReplica( + const TensorPartitionSpec& spec) { + TensorPartitionSpec new_spec = spec; + new_spec.axis_specs[spec.GetPartitionAxis()] = AxisPartitionSpec::CreateReplica(); + return new_spec; + } + + // TODO(wechi): Create a halper to copy-construct a new spec with different sharding axis. + // static TensorPartitionSpec CreateReshard( + // const TensorPartitionSpec& spec, int64_t new_shard_axis) { + // } + + // Helper to debug and generate error message; e.g., + // "TensorPartitionSpec{RS[0], Device Mesh: DeviceMesh{Shape: [4,], Elements: [0,1,2,3,]}}". + std::string ToString() const { + std::ostringstream os; + os << "TensorPartitionSpec{"; + for (const auto& spec : axis_specs) + os << spec.ToString(); + os << ", DeviceMesh: " << device_mesh.ToString() << "}"; + return os.str(); + } + + // Call this in GDB to visualize the spec. + void Print() const { + std::cout << ToString() << std::endl; + } + + // Return true if at least one tensor axis is sharded. + // Otherwise, return false. + bool HasShard() const { + for (const auto& spec : axis_specs) + if (spec.cond == AxisPartitionSpec::Condition::Shard) return true; + return false; + } + + // Return true if no tensor axis is sharded. + // Otherwise, return false. + bool HasNoShard() const { + return !HasShard(); + } + + // Return true if the only sharded tensor axis is `axis`. + // Otherwise, return false. + bool OnlyShardAxis(int64_t axis) const { + ValidateAxisIndex(axis, Rank()); + if (axis < 0) { + axis += Rank(); + } + bool answer = true; + for (int64_t i = 0; i < Rank(); ++i) { + if (i == axis && axis_specs[i].cond != AxisPartitionSpec::Condition::Shard) { + answer = false; + } else if (i != axis && axis_specs[i].cond == AxisPartitionSpec::Condition::Shard) { + answer = false; + } + } + return answer; + } + + // Rank of the owing tensor of this spec. + int64_t Rank() const { + return gsl::narrow(axis_specs.size()); + } + + // Return the number of sharded tensor axes. + // Currently we only support one sharded tensor axis, so + // we may assert the returned value is 1 in related APIs. + int64_t CountShardingAxes() const { + int64_t count = 0; + for (const auto& spec : axis_specs) + if (spec.cond == AxisPartitionSpec::Condition::Shard) count++; + return count; + } + + // Return the AxisPartitionSpec for `axis`-th tensor axis. + const AxisPartitionSpec& GetAxisSpec(int64_t axis) const { + ValidateAxisIndex(axis, Rank()); + if (axis < 0) { + axis += Rank(); + } + return axis_specs.at(axis); + } + + // Get the first sharded tensor axis' sharding spec. + const AxisPartitionSpec& GetPartitionAxisSpec() const { + // TODO: support multiple sharding axes. + ORT_ENFORCE(CountShardingAxes() == 1, "TensorPartitionSpec must have exactly one sharding axis."); + return GetAxisSpec(GetPartitionAxis()); + } + + // Get the first sharded tensor axis' index. + // E.g., spec "RS[0]" should return 1, spec "S[0]R" should return 0, spec "RR" should return -1. + // Returned value -1 means no sharded tensor axis. + int64_t GetPartitionAxis() const { + // TODO: support multiple sharding axes. + ORT_ENFORCE(CountShardingAxes() == 1, "TensorPartitionSpec must have exactly one sharding axis."); + for (int64_t i = 0; i < gsl::narrow(axis_specs.size()); ++i) { + if (axis_specs[i].cond == AxisPartitionSpec::Condition::Shard) { + return i; + } + } + return -1; + } + + // Similarly to GetPartitionAxis(), but returns the negative index of the first sharded tensor axis. + // E.g., spec "RS[0]" should return -1, spec "S[0]R" should return -2, and spec "RR" should return 0. + // Returned value 0 means no sharded tensor axis. + int64_t GetNegativePartitionAxis() const { + // TODO: support multiple sharding axes. + ORT_ENFORCE(CountShardingAxes() == 1, "TensorPartitionSpec must have exactly one sharding axis."); + for (int64_t i = 0; i < gsl::narrow(axis_specs.size()); ++i) { + if (axis_specs[i].cond == AxisPartitionSpec::Condition::Shard) { + return i - axis_specs.size(); + } + } + return 0; + } + + // Return the number of shards along the first sharded tensor axis. + // This value matches the number of devices along the associated mesh axis. + // Return 1 if there is no sharding. + int64_t GetPartitionCount(int64_t axis) const { + ValidateAxisIndex(axis, Rank()); + auto axis_spec = GetAxisSpec(axis); + if (axis_spec.cond == AxisPartitionSpec::Condition::Replica) { + return 1; + } else { + return device_mesh.device_mesh_shape.at(axis_spec.device_mesh_axis); + } + } +}; + +DeviceMesh CreateDeviceMesh( + std::vector device_mesh_shape, + std::vector device_mesh_elements); + +TensorPartitionSpec CreateTensorPartitionSpec( + std::string spec_string, + std::vector device_mesh_shape, + std::vector device_mesh_elements); + +TensorPartitionSpec CreateTensorShardSpec( + const DeviceMesh& device_mesh, + int64_t device_mesh_axis, + int64_t shard_axis, + int64_t tensor_rank); + +// Return the shape of the original tensor before sharding. +// E.g., assume tensor shard's shape is [5, 7] and sharding spec is "S[0]R" +// with 1-D device mesh [0, 1, 2]. +// This function returns [15, 7]. +// +// `shard_shape`: the shape of a shard. +// `spec`: the sharding spec of the original tensor. +TensorShape ComputeOriginShape(const TensorShape& shard_shape, const TensorPartitionSpec& spec); + +// Return the shape of a shard. +// E.g., assume tensor's shape is [15, 7] and sharding spec is "S[0]R" +// with 1-D device mesh [0, 1, 2]. +// This function returns [5, 7]. +// +// `shape`: the shape of the original tensor. +// `spec`: the sharding spec of the original tensor. +TensorShape ComputeShardShape(const TensorShape& shape, const TensorPartitionSpec& spec); + +// Similarly to ComputeShardShape(), but takes a shard axis and counts of all tensor shards +// instead of a spec. +TensorShape ComputeShardShape(const TensorShape source_shape, int64_t shard_axis, int64_t shard_count); + +// Prepend 1's to `shape` to make `left` and `right` have the same rank. +// E.g., if `left` is [3, 7] and `right` is [5, 6, 7], this function returns [1, 3, 7] and [5, 6, 7]. +std::tuple NormalizeShapes(const TensorShape& left, const TensorShape& right); + +// Prepend `R` (aks replicating axis) to `spec` to make `left` and `right` have the same rank. +// E.g., if `left` is S[0]R and `right` is `RRR`, this function returns `RS[0]R` and `RRR`. +std::tuple NormalizeTensorPartitionSpecs( + const TensorPartitionSpec& left, const TensorPartitionSpec& right); + +// Return true if `shape` can be sharded according to `spec`. +// Otherwise, return false. +// Note that an axis is shardable along a device mesh axis only if +// the dimension of the axis is divisible by the number of devices along the device mesh axis. +bool CanShard(const TensorShape& shape, const TensorPartitionSpec& spec); + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 86c1cb93e8b6f..70bc1cd6bf4dc 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -149,6 +149,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Shru class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllReduce); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllGather); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, AllToAll); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedMatMul); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedMatMul); #endif template <> @@ -302,6 +305,9 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif }; diff --git a/onnxruntime/core/graph/contrib_ops/collective_defs.cc b/onnxruntime/core/graph/contrib_ops/collective_defs.cc index 9e63e0d5e83f6..84eed7fae6ac1 100644 --- a/onnxruntime/core/graph/contrib_ops/collective_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/collective_defs.cc @@ -79,6 +79,32 @@ void RegisterCollectiveOps() { .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { propagateShapeAndTypeFromFirstInput(ctx); }); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedMatMul) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("device_mesh_elements", + "", + AttributeProto::INTS) + .Attr("device_mesh_shape", + "", + AttributeProto::INTS) + .Attr("input_shard_specs", + "The sharding spec of \"Y\"; e.g., \"RRR\" if Y is not sharded.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "The sharding spec of \"Y\"; e.g., \"RRR\" if Y is not sharded.", + AttributeProto::STRINGS) + .Input(0, "A", "N-dimensional matrix A", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input(1, "B", "N-dimensional matrix B", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Output(0, "Y", "Matrix multiply results from A * B", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint( + "T", + { + "tensor(float16)", + "tensor(float)", + }, + "Constrain input and output types to float tensors."); } } // namespace contrib diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc index ddcc04cf4a45c..b40f0edb2a8e1 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc @@ -115,6 +115,12 @@ struct ProviderHostCPUImpl : ProviderHostCPU { Status PrepareOutputShape(const Tensor* indices, const int64_t depth_val, const int64_t axis, int64_t& prefix_dim_size, int64_t& suffix_dim_size, TensorShapeVector& output_shape) override { return onnxruntime::PrepareOutputShape(indices, depth_val, axis, prefix_dim_size, suffix_dim_size, output_shape); } // From cpu/tensor/slice.h (direct) + Status SliceBase__FlattenOutputDims(gsl::span input_dimensions, gsl::span output_dims, + TensorShapeVector& starts, TensorShapeVector& ends, TensorShapeVector& steps, + TensorShapeVector*& p_flattened_input_dims, TensorShapeVector*& p_flattened_output_dims) override { + return SliceBase::FlattenOutputDims(input_dimensions, output_dims, starts, ends, steps, p_flattened_input_dims, p_flattened_output_dims); + } + Status SliceBase__PrepareForCompute(gsl::span raw_starts, gsl::span raw_ends, gsl::span raw_axes, diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h index 7d4620f0039eb..7dc80f44b53cd 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.h +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h @@ -63,6 +63,10 @@ struct ProviderHostCPU { virtual Status PrepareOutputShape(const Tensor* indices, const int64_t depth_val, const int64_t axis, int64_t& prefix_dim_size, int64_t& suffix_dim_size, TensorShapeVector& output_shape) = 0; // From cpu/tensor/slice.h + virtual Status SliceBase__FlattenOutputDims(gsl::span input_dimensions, gsl::span output_dims, + TensorShapeVector& starts, TensorShapeVector& ends, TensorShapeVector& steps, + TensorShapeVector*& p_flattened_input_dims, TensorShapeVector*& p_flattened_output_dims) = 0; + virtual Status SliceBase__PrepareForCompute(gsl::span raw_starts, gsl::span raw_ends, gsl::span raw_axes, diff --git a/onnxruntime/core/providers/cpu/tensor/slice.cc b/onnxruntime/core/providers/cpu/tensor/slice.cc index a8cb74a62e02d..e0cd74343b83d 100644 --- a/onnxruntime/core/providers/cpu/tensor/slice.cc +++ b/onnxruntime/core/providers/cpu/tensor/slice.cc @@ -76,9 +76,9 @@ ONNX_CPU_OPERATOR_KERNEL( // e.g. if input shape is { 2, 2, 2, 2, 2 }, output shape is { 2, 2, 1, 2, 2 }, // and the 'steps' value for all dims is 1 except dim-2, then the input shape is coalesced to { 4, 2, 4 } // and the output shape is coalesced to { 4, 1, 4 }. -static void FlattenOutputDims(gsl::span input_dimensions, gsl::span output_dims, - TensorShapeVector& starts, TensorShapeVector& ends, TensorShapeVector& steps, - TensorShapeVector*& p_flattened_input_dims, TensorShapeVector*& p_flattened_output_dims) { +Status SliceBase::FlattenOutputDims(gsl::span input_dimensions, gsl::span output_dims, + TensorShapeVector& starts, TensorShapeVector& ends, TensorShapeVector& steps, + TensorShapeVector*& p_flattened_input_dims, TensorShapeVector*& p_flattened_output_dims) { size_t cur = 0; size_t nxt = 0; while (true) { @@ -131,6 +131,8 @@ static void FlattenOutputDims(gsl::span input_dimensions, gsl::sp ends.resize(cur); steps.resize(cur); } + + return Status::OK(); } // Slice V1-9 & DynamicSlice @@ -138,9 +140,9 @@ Status SliceBase::PrepareForCompute(gsl::span raw_starts, gsl::sp gsl::span raw_axes, SliceOp::PrepareForComputeMetadata& compute_metadata) { ORT_RETURN_IF_ERROR(SliceOp::PrepareForComputeHelper(raw_starts, raw_ends, raw_axes, compute_metadata)); - FlattenOutputDims(compute_metadata.input_dimensions_, compute_metadata.output_dims_, compute_metadata.starts_, - compute_metadata.ends_, compute_metadata.steps_, compute_metadata.p_flattened_input_dims_, - compute_metadata.p_flattened_output_dims_); + ORT_RETURN_IF_ERROR(FlattenOutputDims(compute_metadata.input_dimensions_, compute_metadata.output_dims_, compute_metadata.starts_, + compute_metadata.ends_, compute_metadata.steps_, compute_metadata.p_flattened_input_dims_, + compute_metadata.p_flattened_output_dims_)); return Status::OK(); } @@ -149,9 +151,9 @@ Status SliceBase::PrepareForCompute(gsl::span raw_starts, gsl::sp gsl::span raw_axes, gsl::span raw_steps, SliceOp::PrepareForComputeMetadata& compute_metadata) { ORT_RETURN_IF_ERROR(SliceOp::PrepareForComputeHelper(raw_starts, raw_ends, raw_axes, raw_steps, compute_metadata)); - FlattenOutputDims(compute_metadata.input_dimensions_, compute_metadata.output_dims_, compute_metadata.starts_, - compute_metadata.ends_, compute_metadata.steps_, compute_metadata.p_flattened_input_dims_, - compute_metadata.p_flattened_output_dims_); + ORT_RETURN_IF_ERROR(FlattenOutputDims(compute_metadata.input_dimensions_, compute_metadata.output_dims_, compute_metadata.starts_, + compute_metadata.ends_, compute_metadata.steps_, compute_metadata.p_flattened_input_dims_, + compute_metadata.p_flattened_output_dims_)); return Status::OK(); } diff --git a/onnxruntime/core/providers/cpu/tensor/slice.h b/onnxruntime/core/providers/cpu/tensor/slice.h index 28e76aca4ea21..1503a87931bcf 100644 --- a/onnxruntime/core/providers/cpu/tensor/slice.h +++ b/onnxruntime/core/providers/cpu/tensor/slice.h @@ -38,6 +38,10 @@ class SliceBase { TensorShapeVector& input_axes, TensorShapeVector& input_steps); + static Status FlattenOutputDims(gsl::span input_dimensions, gsl::span output_dims, + TensorShapeVector& starts, TensorShapeVector& ends, TensorShapeVector& steps, + TensorShapeVector*& p_flattened_input_dims, TensorShapeVector*& p_flattened_output_dims); + protected: SliceBase(const OpKernelInfo& info, bool dynamic = false) : dynamic_(dynamic) { diff --git a/onnxruntime/core/providers/cuda/math/matmul.cc b/onnxruntime/core/providers/cuda/math/matmul.cc index 899d506f840a2..e4c37c52a1780 100644 --- a/onnxruntime/core/providers/cuda/math/matmul.cc +++ b/onnxruntime/core/providers/cuda/math/matmul.cc @@ -119,6 +119,161 @@ Status MatMul::ComputeInternal(OpKernelContext* ctx) const { return ComputeDefault(ctx, helper); } +template +Status FuncMatMul( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* A, + const Tensor* B, + float alpha, + bool trans_A, + bool trans_B, + bool trans_batch_A, + bool trans_batch_B, + Tensor* Y) { + typedef typename ToCudaType::MappedType CudaT; + + // Ignore the transpose flag if rank of input being 1. + // Be noted: numpy.transpose on vector does not change anything. + if (A->Shape().NumDimensions() == 1) { + trans_A = false; + } + if (B->Shape().NumDimensions() == 1) { + trans_B = false; + } + + const CudaT cuda_alpha = ToCudaType::FromFloat(alpha); + const CudaT cuda_zero = ToCudaType::FromFloat(0.0f); + + cublasOperation_t cuda_trans_A = trans_A ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t cuda_trans_B = trans_B ? CUBLAS_OP_T : CUBLAS_OP_N; + + MatMulComputeHelper helper; + ORT_RETURN_IF_ERROR( + helper.Compute(A->Shape(), B->Shape(), trans_A, trans_B, trans_batch_A, trans_batch_B, false)); + const int lda = helper.Lda(trans_A); + const int ldb = helper.Ldb(trans_B); + const int ldc = helper.Ldc(); + int64_t stride_A, stride_B, stride_C, batch_count; + auto& device_prop = cuda_kernel->GetDeviceProp(); + + if (helper.OutputOffsets().size() == 1) { + CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( + cuda_kernel->GetCublasHandle(ctx), + cuda_trans_B, + cuda_trans_A, + static_cast(helper.N()), + static_cast(helper.M()), + static_cast(helper.K()), + &cuda_alpha, + reinterpret_cast(B->Data()), + ldb, + reinterpret_cast(A->Data()), + lda, + &cuda_zero, + reinterpret_cast(Y->MutableData()), + ldc, + device_prop)); + return Status::OK(); + } else if (CanUseStridedBatchedGemm(A->Shape(), B->Shape(), + trans_A, trans_B, trans_batch_B, trans_batch_B, stride_A, stride_B, stride_C, batch_count)) { + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(cuda_kernel->GetCublasHandle(ctx), + cuda_trans_B, + cuda_trans_A, + static_cast(helper.N()), + static_cast(helper.M()), + static_cast(helper.K()), + &cuda_alpha, + reinterpret_cast(B->Data()), + ldb, + stride_B, + reinterpret_cast(A->Data()), + lda, + stride_A, + &cuda_zero, + reinterpret_cast(Y->MutableData()), + ldc, + stride_C, + static_cast(batch_count), + device_prop)); + + return Status::OK(); + } + + // Fill offsets when needed. + helper.FillOffsets(); + CudaKernel::CudaAsyncBuffer A_arrays(cuda_kernel, helper.LeftOffsets().size()); + CudaKernel::CudaAsyncBuffer B_arrays(cuda_kernel, helper.RightOffsets().size()); + CudaKernel::CudaAsyncBuffer Y_arrays(cuda_kernel, helper.OutputOffsets().size()); + MatMulComputeHelper::OffsetToArrays(reinterpret_cast(A->Data()), helper.LeftOffsets(), A_arrays.CpuSpan()); + MatMulComputeHelper::OffsetToArrays(reinterpret_cast(B->Data()), helper.RightOffsets(), B_arrays.CpuSpan()); + MatMulComputeHelper::OffsetToArrays(reinterpret_cast(Y->MutableData()), helper.OutputOffsets(), Y_arrays.CpuSpan()); + ORT_RETURN_IF_ERROR(A_arrays.CopyToGpu(ctx->GetComputeStream())); + ORT_RETURN_IF_ERROR(B_arrays.CopyToGpu(ctx->GetComputeStream())); + ORT_RETURN_IF_ERROR(Y_arrays.CopyToGpu(ctx->GetComputeStream())); + + // TF32 provides a huge performance gain for training and inference while preserving FP32 levels of accuracy. + // It requires Ampere or newer GPU, and pointers of matrics shall be aligned (ideal alignment is 16-byte). + // Assume that start memory of input/output tensor is aligned, we only check offsets of sub-matrix per batch here. + cublasMath_t mode = (std::is_same::value && device_prop.major >= 8 && helper.IsBatchedGemmAligned()) + ? CUBLAS_TF32_TENSOR_OP_MATH + : CUBLAS_DEFAULT_MATH; + CublasMathModeSetter math_mode_setter(device_prop, cuda_kernel->GetCublasHandle(ctx), mode); + + // note that onnxruntime OrtValue is row major, while cublas is column major, + // so swap left/right operands + CUBLAS_RETURN_IF_ERROR(cublasGemmBatchedHelper( + cuda_kernel->GetCublasHandle(ctx), + cuda_trans_B, + cuda_trans_A, + static_cast(helper.N()), + static_cast(helper.M()), + static_cast(helper.K()), + &cuda_alpha, + B_arrays.GpuPtr(), + ldb, + A_arrays.GpuPtr(), + lda, + &cuda_zero, + Y_arrays.GpuPtr(), + ldc, + static_cast(helper.OutputOffsets().size()), + device_prop)); + return Status::OK(); +} + +template Status FuncMatMul( + // Use OpKernel and do a pointer cast to unify functional calls with other eps. + // TODO: remove CudaKernel and OpKernelContext. + const CudaKernel* cuda_kernel, + // Do NOT use ctx to access inputs and outputs. + // Inputs and outputs are passed in as function arguments. + OpKernelContext* ctx, + const Tensor* A, + const Tensor* B, + float alpha, + bool trans_A, + bool trans_B, + bool trans_batch_A, + bool trans_batch_B, + Tensor* Y); + +template Status FuncMatMul( + // Use OpKernel and do a pointer cast to unify functional calls with other eps. + // TODO: remove CudaKernel and OpKernelContext. + const CudaKernel* cuda_kernel, + // Do NOT use ctx to access inputs and outputs. + // Inputs and outputs are passed in as function arguments. + OpKernelContext* ctx, + const Tensor* A, + const Tensor* B, + float alpha, + bool trans_A, + bool trans_B, + bool trans_batch_A, + bool trans_batch_B, + Tensor* Y); + template Status MatMul::ComputeDefault(OpKernelContext* ctx, MatMulComputeHelper& helper) const { typedef typename ToCudaType::MappedType CudaT; diff --git a/onnxruntime/core/providers/cuda/math/matmul.h b/onnxruntime/core/providers/cuda/math/matmul.h index 5ea7b30777402..26de1044eeb23 100644 --- a/onnxruntime/core/providers/cuda/math/matmul.h +++ b/onnxruntime/core/providers/cuda/math/matmul.h @@ -31,5 +31,23 @@ class MatMul final : public CudaKernel { const bool trans_batch_a_; const bool trans_batch_b_; }; + +template +Status FuncMatMul( + // Use OpKernel and do a pointer cast to unify functional calls with other eps. + // TODO: remove CudaKernel and OpKernelContext. + const CudaKernel* cuda_kernel, + // Do NOT use ctx to access inputs and outputs. + // Inputs and outputs are passed in as function arguments. + OpKernelContext* ctx, + const Tensor* A, + const Tensor* B, + float alpha, + bool trans_A, + bool trans_B, + bool trans_batch_A, + bool trans_batch_B, + Tensor* Y); + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/slice.cc b/onnxruntime/core/providers/cuda/tensor/slice.cc index 440b19bce9fb6..db285ba547b6a 100644 --- a/onnxruntime/core/providers/cuda/tensor/slice.cc +++ b/onnxruntime/core/providers/cuda/tensor/slice.cc @@ -3,6 +3,7 @@ #include "core/providers/cuda/tensor/slice.h" #include "core/providers/cpu/tensor/utils.h" +#include "core/providers/cpu/tensor/slice_helper.h" #include "core/providers/cuda/tensor/slice_impl.h" namespace onnxruntime { @@ -235,5 +236,58 @@ Status Slice::CallSliceImp(size_t element_size, size_t dimension_count, output_shape); } +Status FuncSlice( + // Use OpKernel and do a pointer cast to unify functional calls with other eps. + // TODO: remove CudaKernel and OpKernelContext. + const CudaKernel* cuda_kernel, + // Do NOT use ctx to access inputs and outputs. + // Inputs and outputs are passed in as function arguments. + OpKernelContext* ctx, + const Tensor* input, + const std::vector& starts, + const std::vector& ends, + const std::vector& axes, + const std::vector& steps, + Tensor* output) { + gsl::span starts_span = gsl::make_span(starts.data(), starts.size()); + gsl::span ends_span = gsl::make_span(ends.data(), ends.size()); + gsl::span axes_span = gsl::make_span(axes.data(), axes.size()); + gsl::span steps_span = gsl::make_span(steps.data(), steps.size()); + const auto& input_shape = input->Shape(); + const auto input_dimensions = input_shape.GetDims(); + + SliceOp::PrepareForComputeMetadata compute_metadata(input_dimensions); + + ORT_RETURN_IF_ERROR( + SliceOp::PrepareForComputeHelper(starts_span, ends_span, axes_span, steps_span, compute_metadata)); + + ORT_RETURN_IF_ERROR(SliceBase::FlattenOutputDims(compute_metadata.input_dimensions_, compute_metadata.output_dims_, compute_metadata.starts_, + compute_metadata.ends_, compute_metadata.steps_, compute_metadata.p_flattened_input_dims_, + compute_metadata.p_flattened_output_dims_)); + + TensorShape output_shape(compute_metadata.output_dims_); + + TArray starts_buffer(compute_metadata.starts_); + TArray steps_buffer(compute_metadata.steps_); + TArray input_strides; + TArray output_strides; + + ORT_RETURN_IF_ERROR(SliceCuda::ComputeSliceStrides(input_shape, input_strides, output_strides, compute_metadata)); + + ORT_RETURN_IF_ERROR(SliceImpl( + cuda_kernel->Stream(ctx), + input->DataType()->Size(), + gsl::narrow_cast(input_dimensions.size()), + starts_buffer, + steps_buffer, + input_strides, + output_strides, + input->DataRaw(), + output->MutableDataRaw(), + output_shape.Size())); + + return Status::OK(); +} + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/slice.h b/onnxruntime/core/providers/cuda/tensor/slice.h index 444e37c2167e8..d5c53611d3421 100644 --- a/onnxruntime/core/providers/cuda/tensor/slice.h +++ b/onnxruntime/core/providers/cuda/tensor/slice.h @@ -38,5 +38,20 @@ class Slice : public CudaKernel, public SliceBase { const TArray& output_strides, OpKernelContext* ctx, const TensorShape& output_shape) const; }; + +Status FuncSlice( + // Use OpKernel and do a pointer cast to unify functional calls with other eps. + // TODO: remove CudaKernel and OpKernelContext. + const CudaKernel* cuda_kernel, + // Do NOT use ctx to access inputs and outputs. + // Inputs and outputs are passed in as function arguments. + OpKernelContext* ctx, + const Tensor* input, + const std::vector& starts, + const std::vector& ends, + const std::vector& axes, + const std::vector& steps, + Tensor* output); + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index d6546ccdd9d5d..d118bbe6d510c 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -510,6 +510,13 @@ bool TileOp::IsTileMemcpy(const TensorShape& input_shape, const int64_t* repeats return g_host_cpu.TileOp__IsTileMemcpy(input_shape, repeats, rank, is_batched_memcpy, num_of_elements_per_batch, num_of_copies_per_batch, num_of_batch_copies); } +Status SliceBase::FlattenOutputDims(gsl::span input_dimensions, gsl::span output_dims, + TensorShapeVector& starts, TensorShapeVector& ends, TensorShapeVector& steps, + TensorShapeVector*& p_flattened_input_dims, TensorShapeVector*& p_flattened_output_dims) { + return g_host_cpu.SliceBase__FlattenOutputDims( + input_dimensions, output_dims, starts, ends, steps, p_flattened_input_dims, p_flattened_output_dims); +} + Status SliceBase::PrepareForCompute(gsl::span raw_starts, gsl::span raw_ends, gsl::span raw_axes, diff --git a/onnxruntime/test/python/onnxruntime_test_distributed.py b/onnxruntime/test/python/onnxruntime_test_distributed.py new file mode 100644 index 0000000000000..7f3cbc254969e --- /dev/null +++ b/onnxruntime/test/python/onnxruntime_test_distributed.py @@ -0,0 +1,317 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import unittest + +import numpy as np +import onnxscript +from mpi4py import MPI +from onnxscript import FLOAT + +import onnxruntime as ort + +MICROSOFT_OPSET = onnxscript.values.Opset(domain="com.microsoft", version=1) +comm = MPI.COMM_WORLD + + +def shard_tensor(X, rank, axis, num_shards): + return np.split(X, num_shards, axis)[rank] + + +class TestDistributedMatMul(unittest.TestCase): + def test_matmul_rs_sr_rr(self): + @onnxscript.script() + def matmul_rs_sr_rr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT: + return MICROSOFT_OPSET.DistributedMatMul( + tensor_x, + tensor_w, + device_mesh_shape=[2], + device_mesh_elements=[0, 1], + input_shard_specs=["RS[0]", "S[0]R"], + output_shard_specs=["RR"], + ) + + rank = comm.Get_rank() + tensor_x = np.array([[1, 2, 3, 4], [3, 4, 5, 6]], dtype=np.float32) + tensor_w = np.array([[1, 1], [2, 2], [3, 3], [4, 4]], dtype=np.float32) + + onnx_model = matmul_rs_sr_rr.to_model_proto( + input_types=[FLOAT[2, "s"], FLOAT["s", 2]], + output_types=[FLOAT[2, 2]], + ) + + sess = ort.InferenceSession( + onnx_model.SerializeToString(), + providers=["CUDAExecutionProvider"], + provider_options=[{"device_id": str(rank)}], + ) + + tensor_shard_x = shard_tensor(tensor_x, rank=rank, axis=1, num_shards=2) + tensor_shard_w = shard_tensor(tensor_w, rank=rank, axis=0, num_shards=2) + + result = sess.run(None, {"tensor_x": tensor_shard_x, "tensor_w": tensor_shard_w}) + + expected = np.matmul(tensor_x, tensor_w) + np.testing.assert_allclose(result[0], expected, rtol=1e-5, atol=1e-8) + + def test_matmul2d_rs_rs_rr(self): + @onnxscript.script() + def matmul_rs_rs_rr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT: + return MICROSOFT_OPSET.DistributedMatMul( + tensor_x, + tensor_w, + device_mesh_shape=[2], + device_mesh_elements=[0, 1], + input_shard_specs=["RS[0]", "RS[0]"], + output_shard_specs=["RR"], + ) + + rank = comm.Get_rank() + tensor_x = np.array([[1, 2, 3, 4], [3, 4, 5, 6]], dtype=np.float32) + tensor_w = np.array([[1, 1], [2, 2], [3, 3], [4, 4]], dtype=np.float32) + + # Shape informaton should match the shapes seen by the operator. + # If the tensor W with shape [4, 2] is sharded following "RS[0]", its shape + # should be [4, 1] in ORT when calling ctx->Input(1)->Shape(). + onnx_model = matmul_rs_rs_rr.to_model_proto( + input_types=[FLOAT[2, "s"], FLOAT[4, "t"]], + output_types=[FLOAT[2, 2]], + ) + + sess = ort.InferenceSession( + onnx_model.SerializeToString(), + providers=["CUDAExecutionProvider"], + provider_options=[{"device_id": str(rank)}], + ) + + tensor_shard_x = shard_tensor(tensor_x, rank=rank, axis=1, num_shards=2) + tensor_shard_w = shard_tensor(tensor_w, rank=rank, axis=1, num_shards=2) + + result = sess.run(None, {"tensor_x": tensor_shard_x, "tensor_w": tensor_shard_w}) + + expected = np.matmul(tensor_x, tensor_w) + np.testing.assert_allclose(result[0], expected, rtol=1e-5, atol=1e-8) + + def test_matmul2d_rs_rs_rs(self): + @onnxscript.script() + def matmul2d_rs_rs_rs(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT: + return MICROSOFT_OPSET.DistributedMatMul( + tensor_x, + tensor_w, + device_mesh_shape=[2], + device_mesh_elements=[0, 1], + input_shard_specs=["RS[0]", "RS[0]"], + output_shard_specs=["RS[0]"], + ) + + rank = comm.Get_rank() + tensor_x = np.array([[1, 2, 3, 4], [3, 4, 5, 6]], dtype=np.float32) + tensor_w = np.array([[1, 1], [2, 2], [3, 3], [4, 4]], dtype=np.float32) + + onnx_model = matmul2d_rs_rs_rs.to_model_proto( + input_types=[FLOAT[2, "s"], FLOAT[4, "t"]], + output_types=[FLOAT[2, "u"]], + ) + + sess = ort.InferenceSession( + onnx_model.SerializeToString(), + providers=["CUDAExecutionProvider"], + provider_options=[{"device_id": str(rank)}], + ) + + tensor_shard_x = shard_tensor(tensor_x, rank=rank, axis=1, num_shards=2) + tensor_shard_w = shard_tensor(tensor_w, rank=rank, axis=1, num_shards=2) + + result = sess.run(None, {"tensor_x": tensor_shard_x, "tensor_w": tensor_shard_w}) + + expected = shard_tensor(np.matmul(tensor_x, tensor_w), rank=rank, axis=1, num_shards=2) + np.testing.assert_allclose(result[0], expected, rtol=1e-5, atol=1e-8) + + def test_matmul_srr_rr_srr(self): + @onnxscript.script() + def matmul_srr_rr_srr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT: + return MICROSOFT_OPSET.DistributedMatMul( + tensor_x, + tensor_w, + device_mesh_shape=[2], + device_mesh_elements=[0, 1], + input_shard_specs=["S[0]RR", "RR"], + output_shard_specs=["S[0]RR"], + ) + + rank = comm.Get_rank() + # Shape [2, 2, 4] + tensor_x = np.array([[[1, 2, 3, 4], [3, 4, 5, 6]], [[1, 2, 3, 4], [3, 4, 5, 6]]], dtype=np.float32) + # Shape [4, 2] + tensor_w = np.array([[1, 1], [2, 2], [3, 3], [4, 4]], dtype=np.float32) + + onnx_model = matmul_srr_rr_srr.to_model_proto( + input_types=[FLOAT["s", 2, 4], FLOAT[4, 2]], + output_types=[FLOAT["s", 2, 2]], + ) + + sess = ort.InferenceSession( + onnx_model.SerializeToString(), + providers=["CUDAExecutionProvider"], + provider_options=[{"device_id": str(rank)}], + ) + + tensor_shard_x = shard_tensor(tensor_x, rank=rank, axis=0, num_shards=2) + tensor_shard_w = tensor_w + + result = sess.run(None, {"tensor_x": tensor_shard_x, "tensor_w": tensor_shard_w}) + + expected = shard_tensor(np.matmul(tensor_x, tensor_w), rank=rank, axis=0, num_shards=2) + np.testing.assert_allclose(result[0], expected, rtol=1e-5, atol=1e-8) + + def test_matmul_srr_rrrr_rsrr(self): + @onnxscript.script() + def matmul_srr_rrrr_rsrr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT: + return MICROSOFT_OPSET.DistributedMatMul( + tensor_x, + tensor_w, + device_mesh_shape=[2], + device_mesh_elements=[0, 1], + input_shard_specs=["S[0]RR", "RRRR"], + output_shard_specs=["RS[0]RR"], + ) + + rank = comm.Get_rank() + # Shape [2, 2, 4] + tensor_x = np.array([[[1, 2, 3, 4], [3, 4, 5, 6]], [[1, 2, 3, 4], [3, 4, 5, 6]]], dtype=np.float32) + # Shape [1, 2, 4, 2] + tensor_w = np.array([[[[1, 1], [2, 2], [3, 3], [4, 4]], [[1, 1], [2, 2], [3, 3], [4, 4]]]], dtype=np.float32) + + onnx_model = matmul_srr_rrrr_rsrr.to_model_proto( + input_types=[FLOAT["s", 2, 4], FLOAT[1, 2, 4, 2]], + output_types=[FLOAT[1, "s", 2, 2]], + ) + + sess = ort.InferenceSession( + onnx_model.SerializeToString(), + providers=["CUDAExecutionProvider"], + provider_options=[{"device_id": str(rank)}], + ) + + tensor_shard_x = shard_tensor(tensor_x, rank=rank, axis=0, num_shards=2) + tensor_shard_w = tensor_w + + result = sess.run(None, {"tensor_x": tensor_shard_x, "tensor_w": tensor_shard_w}) + + expected = shard_tensor(np.matmul(tensor_x, tensor_w), rank=rank, axis=1, num_shards=2) + np.testing.assert_allclose(result[0], expected, rtol=1e-5, atol=1e-8) + + def test_matmul_sr_rs_rr(self): + @onnxscript.script() + def matmul_sr_rs_rr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT: + return MICROSOFT_OPSET.DistributedMatMul( + tensor_x, + tensor_w, + device_mesh_shape=[2], + device_mesh_elements=[0, 1], + input_shard_specs=["S[0]R", "RS[0]"], + output_shard_specs=["RR"], + ) + + rank = comm.Get_rank() + # Shape [4, 2] + tensor_x = np.array([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=np.float32) + # Shape [2, 2] + tensor_w = np.array([[1, 1], [2, 2]], dtype=np.float32) + + onnx_model = matmul_sr_rs_rr.to_model_proto( + input_types=[FLOAT["s", 2], FLOAT[2, "t"]], + output_types=[FLOAT["s", "t"]], + ) + + sess = ort.InferenceSession( + onnx_model.SerializeToString(), + providers=["CUDAExecutionProvider"], + provider_options=[{"device_id": str(rank)}], + ) + + tensor_shard_x = shard_tensor(tensor_x, rank=rank, axis=0, num_shards=2) + tensor_shard_w = shard_tensor(tensor_w, rank=rank, axis=1, num_shards=2) + + result = sess.run(None, {"tensor_x": tensor_shard_x, "tensor_w": tensor_shard_w}) + + expected = np.matmul(tensor_x, tensor_w) + np.testing.assert_allclose(result[0], expected, rtol=1e-5, atol=1e-8) + + def test_matmul_rr_rs_rs(self): + @onnxscript.script() + def matmul_rr_rs_rs(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT: + return MICROSOFT_OPSET.DistributedMatMul( + tensor_x, + tensor_w, + device_mesh_shape=[2], + device_mesh_elements=[0, 1], + input_shard_specs=["RR", "RS[0]"], + output_shard_specs=["RS[0]"], + ) + + rank = comm.Get_rank() + # Shape [4, 2] + tensor_x = np.array([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=np.float32) + # Shape [2, 4] + tensor_w = np.array([[1, 1, 1, 1], [2, 2, 2, 2]], dtype=np.float32) + + onnx_model = matmul_rr_rs_rs.to_model_proto( + input_types=[FLOAT[4, 2], FLOAT[2, "s"]], + output_types=[FLOAT[4, "t"]], + ) + + sess = ort.InferenceSession( + onnx_model.SerializeToString(), + providers=["CUDAExecutionProvider"], + provider_options=[{"device_id": str(rank)}], + ) + + tensor_shard_x = tensor_x + tensor_shard_w = shard_tensor(tensor_w, rank=rank, axis=1, num_shards=2) + + result = sess.run(None, {"tensor_x": tensor_shard_x, "tensor_w": tensor_shard_w}) + + expected = shard_tensor(np.matmul(tensor_x, tensor_w), rank=rank, axis=1, num_shards=2) + np.testing.assert_allclose(result[0], expected, rtol=1e-5, atol=1e-8) + + def test_matmul_rr_sr_rr(self): + @onnxscript.script() + def matmul_rr_sr_rr(tensor_x: FLOAT, tensor_w: FLOAT) -> FLOAT: + return MICROSOFT_OPSET.DistributedMatMul( + tensor_x, + tensor_w, + device_mesh_shape=[2], + device_mesh_elements=[0, 1], + input_shard_specs=["RR", "S[0]R"], + output_shard_specs=["RR"], + ) + + rank = comm.Get_rank() + # Shape [4, 2] + tensor_x = np.array([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=np.float32) + # Shape [2, 6] + tensor_w = np.array([[1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2]], dtype=np.float32) + + onnx_model = matmul_rr_sr_rr.to_model_proto( + input_types=[FLOAT[4, 2], FLOAT["s", 6]], + output_types=[FLOAT[4, 6]], + ) + + sess = ort.InferenceSession( + onnx_model.SerializeToString(), + providers=["CUDAExecutionProvider"], + provider_options=[{"device_id": str(rank)}], + ) + + tensor_shard_x = tensor_x + tensor_shard_w = shard_tensor(tensor_w, rank=rank, axis=0, num_shards=2) + + result = sess.run(None, {"tensor_x": tensor_shard_x, "tensor_w": tensor_shard_w}) + + expected = np.matmul(tensor_x, tensor_w) + np.testing.assert_allclose(result[0], expected, rtol=1e-5, atol=1e-8) + + +if __name__ == "__main__": + unittest.main() diff --git a/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ortmodule-distributed-test-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ortmodule-distributed-test-ci-pipeline.yml index f05d03bb54f9c..654bc0921556a 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ortmodule-distributed-test-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ortmodule-distributed-test-ci-pipeline.yml @@ -124,7 +124,7 @@ stages: --volume $(Build.BinariesDirectory):/build \ --volume /mnist:/mnist \ onnxruntime_ortmodule_distributed_tests_image \ - bash -c "rm -rf /build/RelWithDebInfo/onnxruntime/ && python3 -m pip install mpi4py && python3 -m pip install /build/RelWithDebInfo/dist/onnxruntime*.whl && mpirun -n 4 -x NCCL_DEBUG=INFO python /onnxruntime_src/onnxruntime/test/python/onnxruntime_test_collective.py" \ + bash -c "rm -rf /build/RelWithDebInfo/onnxruntime/ && python3 -m pip install mpi4py onnxscript && python3 -m pip install /build/RelWithDebInfo/dist/onnxruntime*.whl && mpirun -n 4 -x NCCL_DEBUG=INFO python /onnxruntime_src/onnxruntime/test/python/onnxruntime_test_collective.py && mpirun -n 2 -x NCCL_DEBUG=INFO python /onnxruntime_src/onnxruntime/test/python/onnxruntime_test_distributed.py" \ displayName: 'Run onnxruntime_test_collective.py' condition: succeededOrFailed() timeoutInMinutes: 30