From 9c323106735535b6dab6b476648faac0ad185e21 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Thu, 26 Oct 2023 22:33:42 -0700 Subject: [PATCH] Distributed Reshape Implementation (#18068) This DistributedReshape aims at supporting all sharding patterns encountered in llama 2. All patterns found are tested in `TestDistributedReshape` in `onnxruntime_test_distributed.py`. This PR implements algorithms to compute the categories below. - All inputs and outputs are replica, so it's computed like a normal Reshape. - Two-axis fusion (if any of the inputs and outputs are sharded). This category convers, e.g., `[batch, seq, hidden] -> [batch x seq, hidden]`. - Two-axis decomposition (if any of the inputs and outputs are sharded). This category convers, e.g., `[batch x seq, hidden] -> [batch, seq, hidden]`. Review guideline: - Ignore the changes in sharding_spec.h and sharding_spec.cc since they come from another PR #18025. - First, read onnxruntime_test_distributed.py to get familiar with the input/output of DistributedReshape. - Second, check the new APIs in reshape.h/reshape.cc to expose CUDA Reshape kernel to DistributedReshape. - For DistributedReshape, check its `ComputeInternal` for the 3 categories mentioned above. --- cmake/onnxruntime_providers_cuda.cmake | 3 +- cmake/onnxruntime_rocm_hipify.cmake | 2 + .../cuda/collective/distributed_reshape.cc | 861 ++++++++++++++++++ .../cuda/collective/distributed_reshape.h | 40 + .../contrib_ops/cuda/collective/sharding.cc | 8 +- .../cuda/collective/sharding_spec.cc | 14 +- .../cuda/collective/sharding_spec.h | 108 ++- .../contrib_ops/cuda/cuda_contrib_kernels.cc | 8 + .../core/graph/contrib_ops/collective_defs.cc | 45 + onnxruntime/core/providers/cuda/cuda_kernel.h | 10 +- .../core/providers/cuda/tensor/reshape.cc | 75 ++ .../core/providers/cuda/tensor/reshape.h | 59 +- onnxruntime/core/providers/rocm/rocm_kernel.h | 10 +- .../python/onnxruntime_test_distributed.py | 667 ++++++++++++++ 14 files changed, 1870 insertions(+), 40 deletions(-) create mode 100644 onnxruntime/contrib_ops/cuda/collective/distributed_reshape.cc create mode 100644 onnxruntime/contrib_ops/cuda/collective/distributed_reshape.h diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 003012f8da071..02b17ee324f4f 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -38,6 +38,7 @@ "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/sharding.cc" "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_matmul.cc" "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_slice.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_reshape.cc" ) endif() # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio @@ -246,4 +247,4 @@ install(TARGETS onnxruntime_providers_cuda ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) \ No newline at end of file + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index de1458c120016..4ef0584b0273e 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -103,6 +103,8 @@ if (NOT onnxruntime_USE_NCCL) list(APPEND contrib_ops_excluded_files "collective/sharding.cc") list(APPEND contrib_ops_excluded_files "collective/sharding_spec.cc") list(APPEND contrib_ops_excluded_files "collective/distributed_matmul.cc") + list(APPEND contrib_ops_excluded_files "collective/distributed_slice.cc") + list(APPEND contrib_ops_excluded_files "collective/distributed_reshape.cc") endif() set(provider_excluded_files diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_reshape.cc b/onnxruntime/contrib_ops/cuda/collective/distributed_reshape.cc new file mode 100644 index 0000000000000..a0ac40defbee7 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_reshape.cc @@ -0,0 +1,861 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Distributed computation. +#include "distributed_reshape.h" +#include "sharding.h" +#include "sharding_spec.h" +#include "nccl_kernels.h" +#include "mpi_include.h" + +// ORT system. +#include "core/providers/cuda/tensor/transpose.h" +#include "core/providers/cuda/cuda_check_memory.h" + +// std C++. +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +// Return true if src_shape[src_begin:src_end] is the same as +// dst_shape[dst_begin:dst_end]. Otherwise, return false. +// TODO: replace std::vector with gsl::span. +bool CompareSubVectors( + const std::vector& src_shape, + const std::vector& dst_shape, + size_t src_begin, size_t src_end, + size_t dst_begin, size_t dst_end) { + if (src_end - src_begin != dst_end - dst_begin) { + // Sub-vectors have different lengths. + return false; + } + for (size_t src_index = src_begin, dst_index = dst_begin; + src_index < src_end && dst_index < dst_end; + ++src_index, ++dst_index) { + if (src_shape[src_index] != dst_shape[dst_index]) { + // Sub-vectors have different elements. + return false; + } + } + // Sub-vectors have same length and same elements. + return true; +} + +// TODO: replace std::vector with gsl::span. +std::tuple IsTwoAxisFusion( + const std::vector& src_shape, + const std::vector& dst_shape) { + // Return values: + // - bool: whether two consecutive axes are fused. + // - size_t: the axis in destination shape formed by fusing two source axes. + // - size_t: the first axis fused. + // - size_t: the length of fusion. In two-axis fusion considered by this + // function, the length of fusion is always 2. + const size_t src_rank = src_shape.size(); + const size_t dst_rank = dst_shape.size(); + if (src_rank < 2 || dst_rank < 1) { + return std::make_tuple(false, -1, -1, -1); + } + if (src_rank - 1 != dst_rank) { + return std::make_tuple(false, -1, -1, -1); + } + for (size_t i_src = 0; i_src < src_rank; ++i_src) { + if (i_src + 1 > src_rank - 1) { + // We are at src_shape[i] and we need + // src_shape[i + 1] to fuse. + // If we are at the last axis, we cannot fuse. + break; + } + const int64_t prod = src_shape[i_src] * src_shape[i_src + 1]; + + for (size_t i_dst = 0; i_dst < dst_rank; ++i_dst) { + // Check if shape[i_src:i_src+2] (i.e., shape[i_src] and shape[i_src+1]) + // for source tensor are fused into shape[i_dst] for destination tensor. + if (prod != dst_shape[i_dst]) { + continue; + } + // Check if corresponding dimensions before fusion area + // are the same. + const bool prefix_shape_match = CompareSubVectors( + src_shape, + dst_shape, + // Represent src_shape[0:i_src]. + 0, i_src, + // Represent dst_shape[0:i_dst]. + 0, i_dst); + const bool suffix_shape_match = CompareSubVectors( + src_shape, + dst_shape, + // Represent src_shape[i_src+2:]. + i_src + 2, src_rank, + // Represent dst_shape[i_dst+1:]. + i_dst + 1, dst_rank); + if (prefix_shape_match && suffix_shape_match) { + return std::make_tuple( + true, i_dst, i_src, 2); + } + } + } + return std::make_tuple(false, 0, 0, 0); +} + +std::tuple IsTwoAxisDecomposition( + const std::vector& src_shape, + const std::vector& dst_shape) { + // Return values: + // - bool: whether one source axis is decomposed into two consecutive destination axes. + // - size_t: the axis in source shape decomposed into two consecutive destination axes. + // - size_t: the first axis the source axis decomposed into. + // - size_t: the number of decomposed axes. It's always 2 in this function. + return IsTwoAxisFusion(dst_shape, src_shape); +} + +std::vector RepeatVector(const std::vector& vec, int64_t repeat) { + std::vector new_vec; + for (int64_t i = 0; i < repeat; ++i) { + new_vec.insert(new_vec.end(), vec.begin(), vec.end()); + } + return new_vec; +} + +DeviceMesh CreateInterleaveDeviceMesh( + const DeviceMesh& source_mesh, const int64_t repeat) { + // Given a 1-D device mesh [0, 1] and repeat=2, + // return 1-D device mesh [0, 1, 0, 1]. + if (source_mesh.device_mesh_shape.size() != 1) { + throw std::runtime_error("Source mesh shape 1-D."); + } + + // Mesh to return. + DeviceMesh new_mesh; + + std::vector& elements = new_mesh.device_mesh_elements; + for (int64_t i = 0; i < repeat; ++i) { + elements.insert( + elements.end(), + source_mesh.device_mesh_elements.begin(), + source_mesh.device_mesh_elements.end()); + } + + // source mesh must be 1-D so we only care its 1st dimension. + new_mesh.device_mesh_shape.push_back(source_mesh.device_mesh_shape[0] * repeat); + + return new_mesh; +} + +std::tuple ComputeNativeSpecForTwoAxisFusion( + const TensorPartitionSpec& src_spec, + const std::vector& src_shape, + const std::vector& dst_shape, + const int64_t fused_axis_in_src, + const int64_t fusion_axis_in_dst) { + // TODO(wechi): use device mesh stride to support non-1 stride. + // Example: S[0]R, shape=[2, 3], device_mesh=[0, 1] -> S[0], shape = [6], device_mesh=[0, 1] + // Example: RS[0], shape=[2, 3], device_mesh=[0, 1] -> S[0], shape = [6], device_mesh=[0, 1, 0, 1] + // Example: S[0]RR, shape=[2, 3, 5], device_mesh=[0, 1] -> S[0]R, shape = [2, 15], device_mesh=[0, 1] + ORT_ENFORCE(src_spec.CountShardingAxes() == 1, "Tensor to be reshaped has too many sharding axes."); + ORT_ENFORCE(src_spec.device_mesh.device_mesh_shape.size() == 1, "Source device mesh be 1-D."); + + if (src_spec.HasNoShard()) { + return std::make_tuple(true, TensorPartitionSpec::CreateAllReplica(dst_shape.size(), src_spec.device_mesh)); + } else if (src_spec.HasShard() && src_spec.OnlyShardAxis(fused_axis_in_src)) { + // Example: S[0]R, shape=[2, 3], device_mesh=[0, 1] -> S[0], shape = [6], device_mesh=[0, 1] + // Example 1: + // - logical input shape: [2, 8] + // - logical output shape: [16] + // - input sharding spec: S[0]R, device_mesh=[0, 1] + // 1. Device allocation of the original input tensor: + // - Logical tensor. + // [[0, 0, 0, 0, 0, 0, 0, 0], (device assignment) + // [1, 1, 1, 1, 1, 1, 1, 1]] + // [[ 0, 1, 2, 3, 4, 5, 6, 7], (values) + // [ 8, 9, 10, 11, 12, 13, 14, 15]] + // - Device 0's local tensor (shape: [2, 4]). + // [[ 0, 1, 2, 3, 4, 5, 6, 7]] + // - Device 1's local tensor (shape: [2, 4]). + // [[ 8, 9, 10, 11, 12, 13, 14, 15]] + // 2. Deduce local output shape: + // - In the logical Reshape, the 1st and 2nd logical axes are fused, + // so are the corresponding local axes. + // - Local output shape: [8] by fusing both axes in shape [2, 4]. + // 3. Run local reshape (reshape from shape [2, 4] to shape [8]): + // - Device 0's local output tensor. + // [ 0, 1, 2, 3, 4, 5, 6, 7] + // - Device 1's local output tensor. + // [ 8, 9, 10, 11, 12, 13, 14, 15] + // 4. Determine native output sharding spec from local output tensors. + // - Logical output tensor: + // [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + // - Device assignment by comparing local tensors and logical output tensor: + // [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1] + // - S[0] with device_mesh = [0, 1] = input device mesh. + // 5. Native output sharding spec: + // - S[0] with device_mesh [0, 1] + // + // Example 2: + // - logical input shape: [8, 2] + // - logical output shape: [16] + // - input sharding spec: S[0]R, device_mesh=[0, 1] + // 1. Device allocation of the original input tensor: + // - Logical tensor. + // [[0, 0], (device assignment) + // [0, 0], + // [0, 0], + // [0, 0], + // [1, 1], + // [1, 1], + // [1, 1], + // [1, 1]] + // [[ 0, 1], (values) + // [ 2, 3], + // [ 4, 5], + // [ 6, 7], + // [ 8, 9], + // [10, 11], + // [12, 13], + // [14, 15]] + // - Device 0's local tensor (shape: [4, 2]). + // [[ 0, 1], + // [ 2, 3], + // [ 4, 5], + // [ 6, 7]] + // - Device 1's local tensor (shape: [4, 2]). + // [[ 8, 9], + // [10, 11], + // [12, 13], + // [14, 15]] + // 2. Deduce local output shape: + // - In the logical Reshape, the 1st and 2nd logical axes are fused, + // so are the corresponding local axes. + // - Local output shape: [8] by fusing both axes in shape [4, 2]. + // 3. Run local reshape (reshape from shape [4, 2] to shape [8]): + // - Device 0's local output tensor. + // [ 0, 1, 2, 3, 4, 5, 6, 7] + // - Device 1's local output tensor. + // [ 8, 9, 10, 11, 12, 13, 14, 15] + // 4. Determine native output sharding spec from local output tensors. + // - Logical output tensor: + // [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + // - Device assignment by comparing local tensors and logical output tensor: + // [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1] + // - S[0] with device_mesh = [0, 1] = input device mesh. + // 5. Native output sharding spec: + // - S[0] with device_mesh [0, 1] + // + // Example 3: + // - logical input shape: [8, 2] + // - logical output shape: [16] + // - input sharding spec: S[0]R, device_mesh=[0, 1, 0, 1] + // 1. Device allocation of the original input tensor: + // - Logical tensor. + // [[0, 0], (device assignment) + // [0, 0], + // [1, 1], + // [1, 1], + // [0, 0], + // [0, 0], + // [1, 1], + // [1, 1]] + // [[ 0, 1], (values) + // [ 2, 3], + // [ 4, 5], + // [ 6, 7], + // [ 8, 9], + // [10, 11], + // [12, 13], + // [14, 15]] + // - Device 0's local tensor (shape: [4, 2]). + // [[ 0, 1], + // [ 2, 3], + // [ 8, 9], + // [10, 11]] + // - Device 1's local tensor (shape: [4, 2]). + // [[ 4, 5], + // [ 6, 7], + // [12, 13], + // [14, 15]] + // 2. Deduce local output shape: + // - In the logical Reshape, the 1st and 2nd logical axes are fused, + // so are the corresponding local axes. + // - Local output shape: [8] by fusing both axes in shape [4, 2]. + // 3. Run local reshape (reshape from shape [4, 2] to shape [8]): + // - Device 0's local output tensor. + // [ 0, 1, 2, 3, 8, 9, 10, 11] + // - Device 1's local output tensor. + // [ 4, 5, 6, 7, 12, 13, 14, 15] + // 4. Determine native output sharding spec from local output tensors. + // - Logical output tensor: + // [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + // - Device assignment by comparing local tensors and logical output tensor: + // [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1] + // - S[0] with device_mesh = [0, 1] = input device mesh. + // 5. Native output sharding spec: + // - S[0] with device_mesh [0, 1, 0, 1] + + // Reuse original device mesh but shard the fusion axis in output tensor. + auto dst_spec = TensorPartitionSpec::CreateOneTensorAxisOneDeviceMeshAxisSharding( + dst_shape.size(), src_spec.device_mesh, fusion_axis_in_dst, /* 1-D mesh */ 0); + return std::make_tuple(true, dst_spec); + } else if (src_spec.HasShard() && src_spec.OnlyShardAxis(fused_axis_in_src + 1)) { + // Example 1 of determining native output sharding spec: + // - logical input shape: [3, 4] + // - logical output shape: [12] + // - input sharding spec: RS[0], device_mesh=[0, 1, 0, 1] + // 1. Device allocation of the original input tensor: + // - Logical tensor. + // [[0, 1, 0, 1], (device assignment) + // [0, 1, 0, 1], + // [0, 1, 0, 1]] + // [[0, 1, 2, 3], (values) + // [4, 5, 6, 7], + // [8, 9, 10, 11]], + // - Device 0's local tensor. + // [[0, 0], + // [0, 0], + // [0, 0]] + // [[0, 2], + // [4, 6], + // [8, 10]], + // - Device 1's local tensor. + // [[1, 1], + // [1, 1], + // [1, 1]] + // [[1, 3], + // [5, 7], + // [9, 11]], + // 2. Deduce local output shape: + // - In the logical Reshape, the 1st and 2nd logical axes are fused, + // so are the corresponding local axes. + // - Local output shape: [6] by fusing both axes in shape [3, 2]. + // 3. Run local reshape (reshape from [3, 2] to [6]): + // - Device 0's local output tensor. + // [0, 0, 0, 0, 0, 0] + // [0, 2, 4, 6, 8, 10] + // - Device 1's local output tensor. + // [1, 1, 1, 1, 1, 1] + // [1, 3, 5, 7, 9, 11] + // 4. Determine native output sharding spec by comparing local output tensors and logical tensor. + // - Logical output tensor: + // [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] + // - S[0] with device_mesh = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] = [0, 1, 0, 1] * (first fused dimension). + // 5. Native output sharding spec: + // - S[0] with device_mesh = [0, 1, 0, 1] * (first fused dimension) = [0, 1, 0, 1] * 3 + // + // Example 2 of determining native output sharding spec: + // - logical input shape: [3, 8] + // - logical output shape: [24] + // - input sharding spec: RS[0], device_mesh=[0, 1, 0, 1] + // 1. Device allocation of the original input tensor: + // - Logical tensor. + // [[0, 0, 1, 1, 0, 0, 1, 1], (device assignment) + // [0, 0, 1, 1, 0, 0, 1, 1], + // [0, 0, 1, 1, 0, 0, 1, 1]] + // [[ 0, 1, 2, 3, 4, 5, 6, 7], (values) + // [ 8, 9, 10, 11, 12, 13, 14, 15], + // [16, 17, 18, 19, 20, 21, 22, 23]] + // - Device 0's local tensor (shape: [3, 4]). + // [[0, 0, 0, 0], + // [0, 0, 0, 0], + // [0, 0, 0, 0]] + // [[ 0, 1, 4, 5], + // [ 8, 9, 12, 13], + // [16, 17, 20, 21]] + // - Device 1's local tensor (shape: [3, 4]). + // [[1, 1, 1, 1], + // [1, 1, 1, 1], + // [1, 1, 1, 1]] + // [[ 2, 3, 6, 7], + // [10, 11, 14, 15], + // [18, 19, 22, 23]] + // 2. Deduce local output shape: + // - In the logical Reshape, the 1st and 2nd logical axes are fused, + // so are the corresponding local axes. + // - Local output shape: [12] by fusing both axes in shape [3, 4]. + // 3. Run local reshape (reshape from [3, 4] to [12]): + // - Device 0's local output tensor . + // [0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21] + // - Device 1's local output tensor . + // [2, 3, 6, 7, 10, 11, 14, 15, 18, 19, 22, 23] + // 4. Determine native output sharding spec from local output tensors. + // - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23] + // - [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + // - S[0] with device_mesh = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] = . + // 5. Native output sharding spec: + // - S[0] with device_mesh = [0, 1, 0, 1] * (first fused dimension) = [0, 1, 0, 1] * 3 + // + // Example 3: + // - logical input shape: [2, 8] + // - logical output shape: [16] + // - input sharding spec: RS[0], device_mesh=[0, 1, 0, 1] + // 1. Device allocation of the original input tensor: + // - Logical tensor. + // [[0, 0, 1, 1, 0, 0, 1, 1], (device assignment) + // [0, 0, 1, 1, 0, 0, 1, 1]] + // [[ 0, 1, 2, 3, 4, 5, 6, 7], (values) + // [ 8, 9, 10, 11, 12, 13, 14, 15]] + // - Device 0's local tensor (shape: [2, 4]). + // [[0, 0, 0, 0], + // [0, 0, 0, 0]] + // [[ 0, 1, 4, 5], + // [ 8, 9, 12, 13]] + // - Device 1's local tensor (shape: [2, 4]). + // [[1, 1, 1, 1], + // [1, 1, 1, 1]] + // [[ 2, 3, 6, 7], + // [10, 11, 14, 15]] + // 2. Deduce local output shape: + // - In the logical Reshape, the 1st and 2nd logical axes are fused, + // so are the corresponding local axes. + // - Local output shape: [8] by fusing both axes in shape [2, 4]. + // 3. Run local reshape (reshape from [2, 4] to [8]): + // - Device 0's local output tensor . + // [ 0, 1, 4, 5, 8, 9, 12, 13] + // - Device 1's local output tensor . + // [ 2, 3, 6, 7, 10, 11, 14, 15] + // 4. Determine native output sharding spec from local output tensors. + // - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + // - [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + // - S[0] with device_mesh = [0, 1, 0, 1, 0, 1, 0, 1] = [0, 1, 0, 1] * (first fused dimension). + // 5. Native output sharding spec: + // - S[0] with device_mesh = [0, 1, 0, 1] * (first fused dimension) = [0, 1, 0, 1] * 2 + // + // Example 4: + // - logical input shape: [2, 8] + // - logical output shape: [16] + // - input sharding spec: RS[0], device_mesh=[0, 1] + // 1. Device allocation of the original input tensor: + // - Logical tensor. + // [[0, 0, 0, 0, 1, 1, 1, 1], (device assignment) + // [0, 0, 0, 0, 1, 1, 1, 1]] + // [[ 0, 1, 2, 3, 4, 5, 6, 7], (values) + // [ 8, 9, 10, 11, 12, 13, 14, 15]] + // - Device 0's local tensor (shape: [2, 4]). + // [[0, 0, 0, 0], + // [0, 0, 0, 0]] + // [[ 0, 1, 2, 3], + // [ 8, 9, 10, 11]] + // - Device 1's local tensor (shape: [2, 4]). + // [[1, 1, 1, 1], + // [1, 1, 1, 1]] + // [[ 4, 5, 6, 7], + // [12, 13, 14, 15]] + // 2. Deduce local output shape: + // - In the logical Reshape, the 1st and 2nd logical axes are fused, + // so are the corresponding local axes. + // - Local output shape: [8] by fusing both axes in shape [2, 4]. + // 3. Run local reshape (reshape from [2, 4] to [8]): + // - Device 0's local output tensor . + // [ 0, 1, 2, 3, 8, 9, 10, 11] + // - Device 1's local output tensor . + // [ 4, 5, 6, 7, 12, 13, 14, 15] + // 4. Determine native output sharding spec from local output tensors. + // - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + // - [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1] + // - S[0] with device_mesh = [0, 1, 0, 1] = [0, 1] * (first fused dimension). + // 5. Native output sharding spec: + // - S[0] with device_mesh = [0, 1] * (first fused dimension) = [0, 1] * 2 = [0, 1, 0, 1] + + // The output device mesh is the repeats of the original device. + // Let's use Python syntax. If the original device mesh is [0, 1, 0, 1], and + // the first fused dimension is 3, then the output device mesh is [0, 1, 0, 1] * 3. + auto dst_device_mesh = DeviceMesh::Create1D( + src_spec.device_mesh.device_mesh_elements, + src_shape[fused_axis_in_src]); + // Sharding happens in the fusion axis with the new device mesh. + auto dst_spec = TensorPartitionSpec::CreateOneTensorAxisOneDeviceMeshAxisSharding( + dst_shape.size(), dst_device_mesh, fusion_axis_in_dst, /* 1-D mesh */ 0); + return std::make_tuple(true, dst_spec); + } else if (src_spec.HasShard() && (src_spec.GetPartitionAxis() < fused_axis_in_src || src_spec.GetPartitionAxis() > fused_axis_in_src + 1)) { + // It's two-axis fusion but the fused axes is not sharded. + // Example: S[0]RR, shape=[2, 3, 5], device_mesh=[0, 1] -> S[0]R, shape = [2, 15], device_mesh=[0, 1] + auto dst_spec = TensorPartitionSpec::CreateByDropOneAxis( + src_spec, fused_axis_in_src + 1); + return std::make_tuple(true, dst_spec); + } else { + return std::make_tuple(false, TensorPartitionSpec()); + } +} + +// Arguments: +// - device_elements: a vector of device IDs. +// It should only contain unique device IDs or +// repeats of a list of unique device IDs. Otherwise, +// (0, 0) is returned. +// Returns: +// - count per device ID (all device IDs should have the same count) +// - number of unique device IDs +// Examples: +// - [0, 1] -> (2, 1) +// - [0, 1, 2, 0, 1, 2] -> (2, 3) +std::tuple ComputeRepeatAndRepeatStride( + const std::vector& device_elements) { + int64_t first_device_id = device_elements.at(0); + int64_t first_device_id_count = 0; + for (size_t i = 0; i < device_elements.size(); ++i) { + if (device_elements.at(i) == first_device_id) { + ++first_device_id_count; + } + } + size_t repeat_stride = device_elements.size() / first_device_id_count; + + // Check if the device mesh pattern is supported. + // Supported examples: [0, 1, 2] and [0, 1, 0, 1, 0, 1]. + // Unsupported examples: [0, 1, 2, 1, 2, 0] and [0, 1, 2, 0]. + for (size_t repeat = 0; repeat < first_device_id_count; ++repeat) { + for (size_t device_id = 0; device_id < repeat_stride; ++device_id) { + ORT_ENFORCE( + device_elements.at(repeat * repeat_stride + device_id) == device_elements.at(device_id), + "Unsupported device mesh pattern."); + } + } + + // If device_mesh=[0, 1, 2, 0, 1, 2], returns (2, 3), which means + // - each device repeats twice for "2" in (2, 3). + // - there are 3 unique devices for "3" in (2, 3). + return std::make_tuple(first_device_id_count, repeat_stride); +} + +std::tuple ComputeNativeSpecForTwoAxisDecomposition( + const TensorPartitionSpec& src_spec, + const std::vector& src_shape, + const std::vector& dst_shape, + const int64_t decomposed_axis_in_src, + const int64_t decomposition_axis_in_dst) { + // TODO(wechi): use device mesh stride to support non-1 stride. + // Example: S[0], shape=[8], device_mesh=[0, 1] -> S[0]R + // Example: S[0], shape=[8], device_mesh=[0, 1] -> RS[0] + // Example: S[0], shape=[8], device_mesh=[0, 1, 0, 1] -> S[0]R + // Example: S[0], shape=[8], device_mesh=[0, 1, 0, 1] -> RS[0] + // Example: RS[0]R, shape=[8], device_mesh=[0, 1] -> RS[0]RR + // Example: RS[0]R, shape=[8], device_mesh=[0, 1] -> RRS[0]R + if (src_spec.CountShardingAxes() != 1) { + throw std::runtime_error("Too many sharding axes."); + } + if (src_spec.device_mesh.device_mesh_shape.size() != 1) { + throw std::runtime_error("Source device mesh be 1-D."); + } + + if (src_spec.HasNoShard()) { + return std::make_tuple(true, TensorPartitionSpec::CreateAllReplica(dst_shape.size(), src_spec.device_mesh)); + } else if (src_spec.OnlyShardAxis(decomposed_axis_in_src)) { + const int64_t device_stride = src_shape[decomposed_axis_in_src] / src_spec.device_mesh.device_mesh_shape[0]; + if (device_stride >= dst_shape[decomposition_axis_in_dst + 1] && device_stride % dst_shape[decomposition_axis_in_dst + 1] == 0) { + // Since 2nd decomposition dimension is a factor of device stride, + // Sharding happens at 1st decomposition axis in dst. + // device_stride = 10 + // S[0], shape=[20], device=[0, 1] -> S[0]R, shape=[2, 10], device=[0, 1] + // + // device_stride = 8 + // S[0], shape=[16], device=[0, 1] -> RS[0], shape=[1, 16], device=[0, 1] + // + // device_stride = 8 + // S[0], shape=[16], device=[0, 1] -> S[0]R, shape=[4, 4], device=[0, 1] + std::vector dst_axis_specs; + for (size_t src_axis = 0; src_axis < src_shape.size(); ++src_axis) { + if (src_axis != decomposed_axis_in_src) { + // Sharding spec is copied if the axis is not decomposed. + // E.g, shape [5, 6] -> Reshape -> shape [5, 3, 2] + // The spec for "5" is copied. + dst_axis_specs.push_back(AxisPartitionSpec::CreateCopy(src_spec.GetAxisSpec(src_axis))); + } else if (dst_shape[decomposition_axis_in_dst] == 1) { + // S[0] -> RS[0] + // E.g., shape [5] -> Reshape -> shape [1, 5] + // The spec for "5" is copied and "1" is replica. + // This reshape only adds a dummy new axis without affecting + // the underlying sharding status. + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + dst_axis_specs.push_back(AxisPartitionSpec::CreateShard(0)); + } else { + // S[0] -> S[0]R + // E.g., shape [5] -> Reshape -> shape [5, 1] + dst_axis_specs.push_back(AxisPartitionSpec::CreateShard(0)); + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + } + } + // Now, we know sharding happens at decomposed_axis_in_src axis in destination tensor. + // - effective_device_stride along decomposed_axis_in_src: device_stride / dst_shape[decomposed_axis_in_src + 1] + // - The original device patterns repeats: dst_shape[decomposed_axis_in_src] / effective_device_stride times. + const int64_t effective_device_stride = device_stride / dst_shape[decomposed_axis_in_src + 1]; + // How many times a device ID changes along decomposed_axis_in_src axis in destination tensor. + const int64_t number_of_device_changes = dst_shape[decomposed_axis_in_src] / effective_device_stride; + if ((size_t)number_of_device_changes != src_spec.device_mesh.device_mesh_elements.size()) { + throw std::runtime_error("Not supported. Resharding is required."); + } + auto dst_device_mesh = CreateInterleaveDeviceMesh( + src_spec.device_mesh, 1); + return std::make_tuple(true, TensorPartitionSpec::Create(dst_axis_specs, dst_device_mesh)); + } else if (dst_shape[decomposition_axis_in_dst + 1] > device_stride && dst_shape[decomposition_axis_in_dst + 1] % device_stride == 0) { + // Since 2nd decomposition dimension is a multiple of device stride, + // sharding happens at 2nd decomposition axis in dst. + // stride = 4 + // S[0], shape=[8], device=[0, 1] -> S[0]R, shape=[4, 2], device=[0, 1] + // + // stride = 8 + // S[0], shape=[32], device=[0, 1, 0, 1] -> RS[0], shape=[2, 16], device=[0, 1] + std::vector dst_axis_specs; + // How many times a device ID appears. + // E.g., [0, 1, 0, 1, 0, 1] -> 3 + int64_t repeats = 0; + // Number of unique devices. + // E.g., [0, 1, 0, 1, 0, 1] -> 2 + int64_t repeat_stride = 0; + DeviceMesh dst_device_mesh; + std::tie(repeats, repeat_stride) = ComputeRepeatAndRepeatStride(src_spec.device_mesh.device_mesh_elements); + for (size_t src_axis = 0; src_axis < src_shape.size(); ++src_axis) { + if (src_axis != decomposed_axis_in_src) { + dst_axis_specs.push_back(AxisPartitionSpec::CreateCopy(src_spec.GetAxisSpec(src_axis))); + } else if (dst_shape[decomposition_axis_in_dst] == 1) { + // S[0] -> RS[0] + // E.g., shape [5] -> Reshape -> shape [1, 5] + // In this case "1" is added as a dummy axis without affecting + // the underlying sharding status, so we just copy the spec + // for input "5" to output "5". + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + dst_axis_specs.push_back(AxisPartitionSpec::CreateShard(0)); + dst_device_mesh = src_spec.device_mesh; + } else if (dst_shape[decomposition_axis_in_dst + 1] == 1) { + // S[0] -> S[0]R + // E.g., shape [5] -> Reshape -> shape [5, 1] + // In this case "1" is added as a dummy axis without affecting + // the underlying sharding status, so we just copy the spec + // for input "5" to output "5". + dst_axis_specs.push_back(AxisPartitionSpec::CreateShard(0)); + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + dst_device_mesh = src_spec.device_mesh; + } else if (repeats == 1 && dst_shape[decomposition_axis_in_dst + 1] == device_stride * repeat_stride) { + // S[0] -> RS[0] + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + dst_axis_specs.push_back(AxisPartitionSpec::CreateShard(0)); + dst_device_mesh = src_spec.device_mesh; + } else if (repeats != 1 && dst_shape[decomposition_axis_in_dst + 1] % (device_stride * repeat_stride) == 0) { + // S[0] -> RS[0] + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + dst_axis_specs.push_back(AxisPartitionSpec::CreateShard(0)); + // Extract [0, 1] from [0, 1, 0, 1]. + std::vector unique_device_mesh_elements( + src_spec.device_mesh.device_mesh_elements.begin(), + src_spec.device_mesh.device_mesh_elements.begin() + repeat_stride); + // Compute new repeats. + // Example of repeats change from 2 to 1: + // [16]-shape tensor [2, 8]-shape tensor + // with 1-D device mesh -> Reshape -> with 1-D device mesh + // [0, 1, 0, 1] (repeats=2) [0, 1] (repeats=1) + const int64_t new_repeat = dst_shape[decomposition_axis_in_dst + 1] / (device_stride * repeat_stride); + dst_device_mesh.device_mesh_shape.push_back(repeat_stride); + dst_device_mesh.device_mesh_elements = RepeatVector(unique_device_mesh_elements, new_repeat); + } else { + throw std::runtime_error("Not supported. Resharding is required."); + } + } + return std::make_tuple(true, TensorPartitionSpec::Create(dst_axis_specs, dst_device_mesh)); + } else { + // Not supported. Resharding is required. + return std::make_tuple(false, TensorPartitionSpec()); + } + } else { + // Source tensor is sharded on non-decomposed axis. + std::vector dst_axis_specs; + for (size_t src_axis = 0; src_axis < src_shape.size(); ++src_axis) { + if (src_axis != decomposed_axis_in_src) { + dst_axis_specs.push_back(AxisPartitionSpec::CreateCopy(src_spec.GetAxisSpec(src_axis))); + } else { + // R -> RR + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + dst_axis_specs.push_back(AxisPartitionSpec::CreateReplica()); + } + } + + return std::make_tuple(true, TensorPartitionSpec::Create(dst_axis_specs, src_spec.device_mesh)); + } +} + +// Arguments: +// global_data_shape: logical shape of Reshape's 1st input. +// global_shape_span: logical content of Reshape's 2nd input. +// Returns: +// logical shape of Reshape's output. +inline TensorShape InferDistributedReshapeLogicalOutputShape( + const TensorShape& global_data_shape, + const gsl::span& global_shape_span, + const int64_t allow_zero) { + return onnxruntime::cuda::InferReshapeOutputShape( + global_data_shape, + global_shape_span, + allow_zero); +} + +template +DistributedReshape::DistributedReshape(const OpKernelInfo& info) : DistributedKernel(info) { + allow_zero_ = info.GetAttrOrDefault("allowzero", static_cast(0)); +} + +template +Status DistributedReshape::ComputeInternal(OpKernelContext* context) const { + ORT_ENFORCE(context != nullptr); + auto data_tensor = context->Input(0); + auto shape_tensor = context->Input(1); + const auto& data_sharding_spec = input_shard_specs_.at(0); + const auto& shape_sharding_spec = input_shard_specs_.at(1); + const auto& output_sharding_spec = output_shard_specs_.at(0); + + if (data_sharding_spec.HasNoShard() && shape_sharding_spec.HasNoShard() && output_sharding_spec.HasNoShard()) { + // Case: all inputs and outputs are not sharded. + const auto target_shape = onnxruntime::cuda::InferReshapeOutputShape( + data_tensor, + shape_tensor, + allow_zero_); + + auto output_tensor = context->Output(0, target_shape); + + // Copy data from input from output. + return FuncReshape( + this, + context, + data_tensor, + shape_tensor, + allow_zero_, + output_tensor); + } else { + ORT_ENFORCE(shape_sharding_spec.HasNoShard(), + "Shape tensor should not be sharded because it will trigger communication. " + "If sharding shape is needed, please request this feature on Github."); + ORT_ENFORCE(shape_tensor->Shape().NumDimensions() == 1, "Shape must be a 1-D tensor."); + const auto original_data_shape = ComputeOriginShape(data_tensor->Shape(), data_sharding_spec); + const auto original_output_shape = InferDistributedReshapeLogicalOutputShape( + original_data_shape, + shape_tensor->template DataAsSpan(), + allow_zero_); + + // TODO: remove below code after replacing std::vector with TensorShape in other APIs. + std::vector src_shape(original_data_shape.GetDims().begin(), original_data_shape.GetDims().end()); + std::vector dst_shape(original_output_shape.GetDims().begin(), original_output_shape.GetDims().end()); + + // Case: Two axis fusion + bool is_two_axis_fusion = false; + size_t two_axis_fusion_axis_in_dst = 0; + size_t two_axis_fusion_first_fused_axis_in_src = 0; + size_t two_axis_fusion_fused_axis_count = 0; + std::tie( + is_two_axis_fusion, + two_axis_fusion_axis_in_dst, + two_axis_fusion_first_fused_axis_in_src, + two_axis_fusion_fused_axis_count) = IsTwoAxisFusion(src_shape, dst_shape); + + if (is_two_axis_fusion) { + bool is_supported = false; + TensorPartitionSpec native_dst_spec; + std::tie(is_supported, native_dst_spec) = ComputeNativeSpecForTwoAxisFusion( + data_sharding_spec, + src_shape, + dst_shape, + two_axis_fusion_first_fused_axis_in_src, + two_axis_fusion_axis_in_dst); + + if (is_supported && native_dst_spec == output_sharding_spec) { + // In this case, we can apply Reshape with local shape on local tensor without resharding. + // Those local output tensors match the output tensors defined by + // sharding the logical tensor following the native sharding spec. + TensorShape local_shape = ComputeShardShape(original_output_shape, native_dst_spec); + auto output_tensor = context->Output(0, local_shape); + return FuncReshape( + this, + context, + data_tensor, + shape_tensor, + allow_zero_, + output_tensor); + } else { + // TODO: Reshape outputs from `native_dst_spec` to `output_sharding_spec`. + return Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, "Encounter unsupported reshape pattern."); + } + } + + // Case: Two axis decomposition + bool is_two_axis_decomposition = false; + size_t two_axis_decomposition_decomposed_axis_in_src = 0; + size_t two_axis_decomposition_first_factor_axis_in_dst = 0; + size_t two_axis_decomposition_factor_axis_count_in_dst = 0; + std::tie( + is_two_axis_decomposition, + two_axis_decomposition_decomposed_axis_in_src, + two_axis_decomposition_first_factor_axis_in_dst, + two_axis_decomposition_factor_axis_count_in_dst) = IsTwoAxisDecomposition(src_shape, dst_shape); + + if (is_two_axis_decomposition) { + bool is_supported = false; + TensorPartitionSpec native_dst_spec; + std::tie(is_supported, native_dst_spec) = ComputeNativeSpecForTwoAxisDecomposition( + data_sharding_spec, + src_shape, + dst_shape, + two_axis_decomposition_decomposed_axis_in_src, + two_axis_decomposition_first_factor_axis_in_dst); + + if (is_supported && native_dst_spec == output_sharding_spec) { + // In this case, we can apply Reshape with local shape on local tensor without resharding. + // Those local output tensors match the output tensors defined by + // sharding the logical tensor following the native sharding spec. + TensorShape local_shape = ComputeShardShape(original_output_shape, native_dst_spec); + auto output_tensor = context->Output(0, local_shape); + return FuncReshape( + this, + context, + data_tensor, + shape_tensor, + allow_zero_, + output_tensor); + } else { + // TODO: Reshape outputs from `native_dst_spec` to `output_sharding_spec`. + return Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, "Encounter unsupported reshape pattern."); + } + } + } + + return Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, "Encounter unsupported reshape pattern."); +} + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReshape, + kMSDomain, + 1, + int64_t, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .AllocateInputsContiguously() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReshape); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReshape, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .AllocateInputsContiguously() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReshape); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReshape, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .AllocateInputsContiguously() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReshape); + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_reshape.h b/onnxruntime/contrib_ops/cuda/collective/distributed_reshape.h new file mode 100644 index 0000000000000..e251c3cdc38d7 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_reshape.h @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "sharding_spec.h" +#include "sharding.h" +#include "core/framework/tensor_shape.h" +#include "core/providers/cuda/tensor/reshape.h" +#include "core/providers/cuda/cuda_kernel.h" + +#include +#include +#include +#include +#include +#include + +#pragma once + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +class DistributedReshape final : public DistributedKernel { + public: + explicit DistributedReshape(const OpKernelInfo& info); + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + int64_t allow_zero_; +}; + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding.cc b/onnxruntime/contrib_ops/cuda/collective/sharding.cc index dfd5f589355df..b6b509023a1a9 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharding.cc +++ b/onnxruntime/contrib_ops/cuda/collective/sharding.cc @@ -30,7 +30,7 @@ void GatherTensor( const Tensor* tensor, Tensor* gathered) { const int64_t shard_axis = spec.GetPartitionAxis(); - const int64_t shard_count = spec.GetPartitionCount(shard_axis); + const int64_t shard_count = spec.GetUniqueDeviceCount(shard_axis); FuncAllGather( nccl_kernel, @@ -51,7 +51,7 @@ std::unique_ptr GatherTensor( const TensorPartitionSpec& spec, const Tensor* tensor) { const int64_t shard_axis = spec.GetPartitionAxis(); - const int64_t shard_count = spec.GetPartitionCount(shard_axis); + const int64_t shard_count = spec.GetUniqueDeviceCount(shard_axis); TensorShape gathered_shape(tensor->Shape()); gathered_shape[shard_axis] *= shard_count; @@ -82,7 +82,7 @@ void ShardTensor( const Tensor* tensor, Tensor* shard_tensor) { const int64_t shard_axis = spec.GetPartitionAxis(); - const int64_t shard_count = spec.GetPartitionCount(shard_axis); + const int64_t shard_count = spec.GetUniqueDeviceCount(shard_axis); TensorShape shard_shape = ComputeShardShape( tensor->Shape(), shard_axis, @@ -118,7 +118,7 @@ std::unique_ptr ShardTensor( TensorShape shard_shape = ComputeShardShape( tensor->Shape(), spec.GetPartitionAxis(), - spec.GetPartitionCount(spec.GetPartitionAxis())); + spec.GetUniqueDeviceCount(spec.GetPartitionAxis())); auto shard_buffer = Tensor::Create(tensor->DataType(), shard_shape, alloc); // Shard with pre-allocated buffer. diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc index 220938f3ceaef..20c936e1b6718 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc +++ b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.cc @@ -129,7 +129,7 @@ TensorShape ComputeOriginShape(const TensorShape& shard_shape, const TensorParti } TensorShape shape(shard_shape); const int64_t axis = spec.GetPartitionAxis(); - shape[axis] *= spec.GetPartitionCount(axis); + shape[axis] *= spec.GetUniqueDeviceCount(axis); return shape; } @@ -140,7 +140,15 @@ TensorShape ComputeShardShape(const TensorShape& shape, const TensorPartitionSpe return shard_shape; } const int64_t axis = spec.GetPartitionAxis(); - shard_shape[axis] /= spec.GetPartitionCount(axis); + const int64_t unique_device_count = spec.GetUniqueDeviceCount(axis); + ORT_ENFORCE(shard_shape[axis] % unique_device_count == 0, "Number of shards must be divisible by sharded axis' dimension."); + // If a [8, 16]-tensor is shared by device mesh [0, 1, 0, 1] along axis=1 (2nd axis), + // the local tensors on device 0 & 1 have same shape [8, 8 (from 16/2)] instead of + // [8, 4 (from 16/4)]. The reason is that + // - First, the original tensor are split into 4 sub-tensors [8, 4] along the 2nd axis. + // - The 1st and 3rd sub-tensors are concatenated along axis=1 to one tensor on device 0. + // - The 2nd and 4th sub-tensors are concatenated along axis=1 to one tensor on device 1. + shard_shape[axis] /= unique_device_count; return shard_shape; } @@ -202,7 +210,7 @@ bool CanShard(const TensorShape& shape, const TensorPartitionSpec& spec) { if (axis < 0 || gsl::narrow(axis) >= shape.NumDimensions()) { return false; } - if (shape[axis] % spec.GetPartitionCount(axis) != 0) { + if (shape[axis] % spec.GetDeviceCount(axis) != 0) { return false; } return true; diff --git a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h index 6bdf5699c2682..5185c41e6888c 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h +++ b/onnxruntime/contrib_ops/cuda/collective/sharding_spec.h @@ -76,6 +76,43 @@ class DeviceMesh { void Print() const { std::cout << ToString() << std::endl; } + + static DeviceMesh Create1D(std::vector device_mesh_elements, size_t repeats = 1) { + DeviceMesh device_mesh; + device_mesh.device_mesh_shape.push_back(device_mesh_elements.size() * repeats); + for (size_t i = 0; i < repeats; ++i) { + device_mesh.device_mesh_elements.insert( + device_mesh.device_mesh_elements.end(), + device_mesh_elements.begin(), + device_mesh_elements.end()); + } + return device_mesh; + } + + // If the two meshes have the same shape and elements, return true. + // Otherwise, return false. + bool operator==(const DeviceMesh& other) const { + if (device_mesh_shape.size() != other.device_mesh_shape.size() || + device_mesh_elements.size() != other.device_mesh_elements.size()) { + return false; + } + + for (size_t i = 0; i < device_mesh_elements.size(); ++i) { + if (device_mesh_elements.at(i) != other.device_mesh_elements.at(i)) { + return false; + } + } + for (size_t i = 0; i < device_mesh_shape.size(); ++i) { + if (device_mesh_shape.at(i) != other.device_mesh_shape.at(i)) { + return false; + } + } + return true; + } + + bool operator!=(const DeviceMesh& other) const { + return !(*this == other); + } }; class AxisPartitionSpec { @@ -114,6 +151,10 @@ class AxisPartitionSpec { return AxisPartitionSpec(Condition::Shard, device_mesh_axis); } + static AxisPartitionSpec CreateCopy(const AxisPartitionSpec& spec) { + return AxisPartitionSpec(spec.cond, spec.device_mesh_axis); + } + // A normal ctor. // TODO(wechi): Consider to hide it and revise the `public` members/functions // exposed to the user. @@ -132,6 +173,14 @@ class AxisPartitionSpec { void Print() const { std::cout << ToString() << std::endl; } + + bool operator==(const AxisPartitionSpec& other) const { + return cond == other.cond && device_mesh_axis == other.device_mesh_axis; + } + + bool operator!=(const AxisPartitionSpec& other) const { + return !(*this == other); + } }; // Return true if `axis` is a valid axis index for a tensor of rank `rank`. @@ -193,6 +242,32 @@ class TensorPartitionSpec { // const TensorPartitionSpec& spec, int64_t new_shard_axis) { // } + // Copy-construct `spec` but with all tensor axes replicated. + // The new spec have the same number of axis specs and the same device mesh. + static TensorPartitionSpec CreateAllReplica( + const size_t rank, const DeviceMesh& device_mesh) { + std::vector axis_specs(rank, AxisPartitionSpec::CreateReplica()); + return TensorPartitionSpec::Create(axis_specs, device_mesh); + } + + static TensorPartitionSpec CreateOneTensorAxisOneDeviceMeshAxisSharding( + const size_t rank, const DeviceMesh& device_mesh, const size_t tensor_axis, const size_t device_mesh_axis) { + std::vector axis_specs(rank, AxisPartitionSpec::CreateReplica()); + axis_specs[tensor_axis] = AxisPartitionSpec::CreateShard(device_mesh_axis); + return TensorPartitionSpec::Create(axis_specs, device_mesh); + } + + static TensorPartitionSpec CreateByDropOneAxis( + const TensorPartitionSpec& TensorPartitionSpec, const size_t axis_to_drop) { + std::vector axis_specs; + for (size_t i = 0; i < TensorPartitionSpec.axis_specs.size(); ++i) { + if (i != axis_to_drop) { + axis_specs.push_back(TensorPartitionSpec.axis_specs[i]); + } + } + return TensorPartitionSpec::Create(axis_specs, TensorPartitionSpec.device_mesh); + } + // Helper to debug and generate error message; e.g., // "TensorPartitionSpec{RS[0], Device Mesh: DeviceMesh{Shape: [4,], Elements: [0,1,2,3,]}}". std::string ToString() const { @@ -303,7 +378,7 @@ class TensorPartitionSpec { // Return the number of shards along the first sharded tensor axis. // This value matches the number of devices along the associated mesh axis. // Return 1 if there is no sharding. - int64_t GetPartitionCount(int64_t axis) const { + int64_t GetDeviceCount(int64_t axis) const { ValidateAxisIndex(axis, Rank()); auto axis_spec = GetAxisSpec(axis); if (axis_spec.cond == AxisPartitionSpec::Condition::Replica) { @@ -312,6 +387,37 @@ class TensorPartitionSpec { return device_mesh.device_mesh_shape.at(axis_spec.device_mesh_axis); } } + + // Similar to GetDeviceCount(), but returns the number of unique devices + // along the first sharded tensor axis. + int64_t GetUniqueDeviceCount(int64_t axis) const { + ValidateAxisIndex(axis, Rank()); + auto axis_spec = GetAxisSpec(axis); + if (axis_spec.cond == AxisPartitionSpec::Condition::Replica) { + return 1; + } else { + std::set device_ids( + device_mesh.device_mesh_elements.begin(), + device_mesh.device_mesh_elements.end()); + return device_ids.size(); + } + } + + bool operator==(const TensorPartitionSpec& other) const { + if (axis_specs.size() != other.axis_specs.size()) { + return false; + } + for (size_t i = 0; i < axis_specs.size(); ++i) { + if (!(axis_specs.at(i) == other.axis_specs.at(i))) { + return false; + } + } + return device_mesh == other.device_mesh; + } + + bool operator!=(const TensorPartitionSpec& other) const { + return !(*this == other); + } }; // Parse "[0, 1, 2, 3]" as std::vector{0, 1, 2, 3}. diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index e762a80cb0e2f..29ca8124bfd05 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -165,6 +165,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedSlice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedSlice); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedReshape); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReshape); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReshape); #endif template <> @@ -334,6 +338,10 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif }; diff --git a/onnxruntime/core/graph/contrib_ops/collective_defs.cc b/onnxruntime/core/graph/contrib_ops/collective_defs.cc index 97befe2a58301..8082b8c010e91 100644 --- a/onnxruntime/core/graph/contrib_ops/collective_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/collective_defs.cc @@ -191,6 +191,51 @@ void RegisterCollectiveOps() { .Output(0, "output", "Sliced data tensor.", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensor types.") .TypeConstraint("Tind", {"tensor(int32)", "tensor(int64)"}, "Constrain indices to integer types"); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedReshape) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) + .Attr("input_shard_specs", + "The sharding spec of inputs. " + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "Similar to input_shard_specs but for outputs.", + AttributeProto::STRINGS) + .Attr( + "allowzero", + "(Optional) By default, when any value in the 'shape' input is equal to zero " + "the corresponding dimension value is copied from the input tensor dynamically. " + "allowzero=1 indicates that if any value in the 'shape' input is set to zero, " + "the zero value is honored, similar to NumPy.", + AttributeProto::INT, + static_cast(0)) + .Input(0, "data", "An input tensor.", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input( + 1, + "shape", + "Specified shape for output.", + "tensor(int64)", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "reshaped", "Reshaped data.", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensor types."); } } // namespace contrib diff --git a/onnxruntime/core/providers/cuda/cuda_kernel.h b/onnxruntime/core/providers/cuda/cuda_kernel.h index f8b92eface52f..e3106e41e77c8 100644 --- a/onnxruntime/core/providers/cuda/cuda_kernel.h +++ b/onnxruntime/core/providers/cuda/cuda_kernel.h @@ -176,17 +176,17 @@ class CudaKernel : public OpKernel { return provider_->ComputeStream(); } + inline Status CopyTensor(const Tensor& src, Tensor& dst, onnxruntime::Stream& stream) const { + auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(src.Location().device, dst.Location().device); + return gpu_data_transfer->CopyTensorAsync(src, dst, stream); + } + protected: template inline const T* GetConstOnes(size_t count, cudaStream_t stream) const { return provider_->template GetConstOnes(count, stream); } - inline Status CopyTensor(const Tensor& src, Tensor& dst, onnxruntime::Stream& stream) const { - auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(src.Location().device, dst.Location().device); - return gpu_data_transfer->CopyTensorAsync(src, dst, stream); - } - inline int GetDeviceId() const { return provider_->GetDeviceId(); } private: diff --git a/onnxruntime/core/providers/cuda/tensor/reshape.cc b/onnxruntime/core/providers/cuda/tensor/reshape.cc index 3c6d900cee9a4..ab364c274a32d 100644 --- a/onnxruntime/core/providers/cuda/tensor/reshape.cc +++ b/onnxruntime/core/providers/cuda/tensor/reshape.cc @@ -6,6 +6,81 @@ namespace onnxruntime { namespace cuda { +TensorShape InferReshapeOutputShape( + const TensorShape& data_tensor_shape, // Data tensor's shape. + const gsl::span& shape_span, // Shape that data tensor reshape to. + bool allow_zero) { + TensorShapeVector shape_vector(shape_span.begin(), shape_span.end()); + ReshapeHelper helper(data_tensor_shape, shape_vector, allow_zero); + return TensorShape(shape_vector); +} + +TensorShape InferReshapeOutputShape(const Tensor* src, const Tensor* shape, bool allow_zero) { + ORT_ENFORCE(shape != nullptr, "Cannot reshape to a null shape."); + ORT_ENFORCE(shape->Shape().NumDimensions() == 1, "Shape must be an 1-D tensor."); + ORT_ENFORCE(shape->Location().device.Type() == OrtDevice::CPU, "Shape must be on CPU."); + + return InferReshapeOutputShape( + src->Shape(), + shape->template DataAsSpan(), + allow_zero); +} + +Status FuncReshape( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* X, + const Tensor* shape, + const bool /*allow_zero*/, + Tensor* Y) { + if (!X) return Status(common::ONNXRUNTIME, common::FAIL, "Missing data tensor to be reshaped."); + if (!shape) return Status(common::ONNXRUNTIME, common::FAIL, "Missing shape tensor for reshaping."); + if (shape->Shape().NumDimensions() != 1) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, FAIL, "The shape tensor for reshaping must be a vector, but got ", shape->Shape(), "."); + } + if (shape->Location().device.Type() != OrtDevice::CPU) { + return Status(common::ONNXRUNTIME, common::FAIL, "Shape tensor must be on CPU."); + } + + const void* src_data = X->DataRaw(); + void* dst_data = Y->MutableDataRaw(); + // If source and target pointers are not equal (non-inplace operation), we need to copy the data. + if (src_data != dst_data) { + ORT_ENFORCE(ctx->GetComputeStream()); + ORT_RETURN_IF_ERROR(cuda_kernel->CopyTensor(*X, *Y, *ctx->GetComputeStream())); + } + + return Status::OK(); +} + +std::unique_ptr FuncReshape( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* X, + const Tensor* shape, + const bool allow_zero) { + // TODO(wechi): Study if Tensor can be created as view to existing tensor. + // This feature can refine code for re-sharding and shape broadcasting. + + ORT_ENFORCE(X != nullptr, "Missing data tensor to be reshaped."); + ORT_ENFORCE(shape != nullptr, "Missing shape tensor for reshaping."); + ORT_ENFORCE(shape->Shape().NumDimensions() == 1, "The shape tensor for reshaping must be a vector, but got ", shape->Shape(), "."); + ORT_ENFORCE(shape->Location().device.Type() == OrtDevice::CPU, "Shape tensor must be on CPU."); + + // Calculate output's shape. + auto dst_shape = InferReshapeOutputShape(X, shape, allow_zero); + + // Pre-allocate output. + AllocatorPtr alloc; + ORT_ENFORCE(ctx->GetTempSpaceAllocator(&alloc).IsOK()); + auto Y = Tensor::Create(X->DataType(), dst_shape, alloc); + + // Do reshape. It's equivalent to memcpy. + ORT_ENFORCE(FuncReshape(cuda_kernel, ctx, X, shape, allow_zero, Y.get()).IsOK()); + return Y; +} + ONNX_OPERATOR_KERNEL_EX( Reshape, kOnnxDomain, diff --git a/onnxruntime/core/providers/cuda/tensor/reshape.h b/onnxruntime/core/providers/cuda/tensor/reshape.h index 01e933e65888f..8f33265071ed3 100644 --- a/onnxruntime/core/providers/cuda/tensor/reshape.h +++ b/onnxruntime/core/providers/cuda/tensor/reshape.h @@ -10,6 +10,39 @@ namespace onnxruntime { namespace cuda { +// Deduce output shape from ONNX Reshape's inputs. +// +// Arguments: +// data_tensor_shape: The shape of the data tensor (i.e., 1st input). +// shape_span: Elements in the shape tensor (i.e., 2nd input). +// +// Returns: +// The output shape of this Reshape. No symbolic values such as "-1" or "0". +TensorShape InferReshapeOutputShape( + const TensorShape& data_tensor_shape, + const gsl::span& shape_span, + bool allow_zero); + +TensorShape InferReshapeOutputShape( + const Tensor* src, + const Tensor* shape, + bool allow_zero); + +Status FuncReshape( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* X, + const Tensor* shape, + const bool /*allow_zero*/, + Tensor* Y); + +std::unique_ptr FuncReshape( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* X, + const Tensor* shape, + const bool allow_zero); + class Reshape final : public CudaKernel { public: Reshape(const OpKernelInfo& info) : CudaKernel(info), @@ -18,27 +51,11 @@ class Reshape final : public CudaKernel { Status ComputeInternal(OpKernelContext* context) const override { // Copy the second input tensor into the shape vector - const Tensor* shapeTensor = context->Input(1); - if (shapeTensor == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); - if (shapeTensor->Shape().NumDimensions() != 1) return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "A shape tensor must be a vector tensor, got ", shapeTensor->Shape().NumDimensions(), " dimensions"); - auto data_span = shapeTensor->template DataAsSpan(); - TensorShapeVector shape(data_span.begin(), data_span.end()); - const Tensor* X = context->Input(0); - if (X == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); - const TensorShape& X_shape = X->Shape(); - - ReshapeHelper helper(X_shape, shape, allow_zero_); - - Tensor* Y = context->Output(0, TensorShape(shape)); - const void* source = X->DataRaw(); - void* target = Y->MutableDataRaw(); - // If source and target pointers are not equal (non-inplace operation), we need to copy the data. - if (target != source) { - ORT_ENFORCE(context->GetComputeStream()); - ORT_RETURN_IF_ERROR(CopyTensor(*X, *Y, *context->GetComputeStream())); - } - - return Status::OK(); + const Tensor* data_tensor = context->Input(0); + const Tensor* shape_tensor = context->Input(1); + const auto target_shape = InferReshapeOutputShape(data_tensor, shape_tensor, allow_zero_); + Tensor* output_tensor = context->Output(0, target_shape); + return FuncReshape(this, context, data_tensor, shape_tensor, allow_zero_, output_tensor); } private: diff --git a/onnxruntime/core/providers/rocm/rocm_kernel.h b/onnxruntime/core/providers/rocm/rocm_kernel.h index 463c1cf0d2ea6..c0b7d4722d3e4 100644 --- a/onnxruntime/core/providers/rocm/rocm_kernel.h +++ b/onnxruntime/core/providers/rocm/rocm_kernel.h @@ -173,17 +173,17 @@ class RocmKernel : public OpKernel { return provider_->PerThreadDefaultMiopenHandle(); } + inline Status CopyTensor(const Tensor& src, Tensor& dst, onnxruntime::Stream& stream) const { + auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(src.Location().device, dst.Location().device); + return gpu_data_transfer->CopyTensorAsync(src, dst, stream); + } + protected: template inline const T* GetConstOnes(size_t count, hipStream_t stream) const { return provider_->template GetConstOnes(count, stream); } - inline Status CopyTensor(const Tensor& src, Tensor& dst, onnxruntime::Stream& stream) const { - auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(src.Location().device, dst.Location().device); - return gpu_data_transfer->CopyTensorAsync(src, dst, stream); - } - inline int GetDeviceId() const { return provider_->GetDeviceId(); } private: diff --git a/onnxruntime/test/python/onnxruntime_test_distributed.py b/onnxruntime/test/python/onnxruntime_test_distributed.py index a9b55122c6806..2acca4a8f22ae 100644 --- a/onnxruntime/test/python/onnxruntime_test_distributed.py +++ b/onnxruntime/test/python/onnxruntime_test_distributed.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import unittest +from typing import Tuple import numpy as np import onnxscript @@ -18,6 +19,672 @@ def shard_tensor(X, rank, axis, num_shards): return np.split(X, num_shards, axis)[rank] +def shard_tensor_per_device_mesh(X, rank, axis, device_mesh): + if axis is None: + return X + shards = np.split(X, len(device_mesh), axis) + selected_shards = tuple(shard for device_id, shard in zip(device_mesh, shards) if device_id == rank) + return np.concatenate(selected_shards, axis=axis) + + +def translate_device_mesh_to_attrs(device_mesh: np.ndarray): + device_mesh_shape = "[" + ",".join(str(dim) for dim in device_mesh.shape) + "]" + device_mesh_elements = "[" + ",".join(str(elem) for elem in device_mesh.flat) + "]" + return device_mesh_shape, device_mesh_elements + + +def parse_sharding_spec(spec: str): + axis_conditions = [] + sharding_device_axes = [] + token_index = 0 + while True: + token = spec[token_index] + if token == "R": + axis_conditions.append("R") + sharding_device_axes.append(None) + token_index += 1 + elif token == "S": + axis_conditions.append("S") + # Move token pointer to "["" + token_index += 1 + assert spec[token_index] == "[" + number_tokens = "" + while True: + token_index += 1 + token = spec[token_index] + if token == "]": + break + number_tokens += token + assert spec[token_index] == "]" + # Skip "]" and point to next S/R token + token_index += 1 + sharding_device_axes.append(int(number_tokens)) + else: + raise ValueError(f"Invalid spec: {spec}") + if token_index >= len(spec): + break + return axis_conditions, sharding_device_axes + + +def find_shard_axis(axis_conditions, shard_device_axes): + sharded_axis = None + sharded_axis_count = 0 + for i, cond in enumerate(axis_conditions): + if cond == "S": + sharded_axis = i + sharded_axis_count += 1 + assert sharded_axis_count in (0, 1), "Can shard at most one axis per tensor." + if sharded_axis is not None: + assert shard_device_axes[sharded_axis] == 0, "Device mesh must be 1-D, so 0 is the only valid device mesh axis." + return sharded_axis + + +def shard_tensor_per_spec(tensor: np.ndarray, rank: int, spec: str, device_mesh: np.ndarray): + axis_conditions, shard_device_axes = parse_sharding_spec(spec) + sharded_axis = find_shard_axis(axis_conditions, shard_device_axes) + return shard_tensor_per_device_mesh(tensor, rank, sharded_axis, list(device_mesh.flat)) + + +class TestDistributedReshape(unittest.TestCase): + def _check_distributed_reshape( + self, + shape: Tuple[int, ...], + target_shape: Tuple[int, ...], + input_device_meshs: np.ndarray, + input_shard_specs: Tuple[str, ...], + output_device_meshs: np.ndarray, + output_shard_specs: Tuple[str, ...], + ): + assert all(len(mesh.shape) == 1 for mesh in input_device_meshs) + assert all(len(mesh.shape) == 1 for mesh in output_device_meshs) + assert len(input_device_meshs) == len(input_shard_specs) + assert len(output_device_meshs) == len(output_shard_specs) + + input_device_mesh_shapes = [] + input_device_mesh_elements = [] + for device_mesh in input_device_meshs: + device_mesh_shape, device_mesh_element = translate_device_mesh_to_attrs(device_mesh) + input_device_mesh_shapes.append(device_mesh_shape) + input_device_mesh_elements.append(device_mesh_element) + + output_device_mesh_shapes = [] + output_device_mesh_elements = [] + for device_mesh in output_device_meshs: + device_mesh_shape, device_mesh_element = translate_device_mesh_to_attrs(device_mesh) + output_device_mesh_shapes.append(device_mesh_shape) + output_device_mesh_elements.append(device_mesh_element) + + @onnxscript.script() + def distributed_reshape_instance(data_tensor: FLOAT, shape_tensor: INT64): + return MICROSOFT_OPSET.DistributedReshape( + data_tensor, + shape_tensor, + input_device_mesh_shapes=input_device_mesh_shapes, + input_device_mesh_elements=input_device_mesh_elements, + input_shard_specs=input_shard_specs, + output_device_mesh_shapes=output_device_mesh_shapes, + output_device_mesh_elements=output_device_mesh_elements, + output_shard_specs=output_shard_specs, + ) + + rank = comm.Get_rank() + data_tensor = np.arange(np.prod(shape), dtype=np.float32).reshape(*shape) + shape_tensor = np.array( + target_shape, + dtype=np.int64, + ) + + local_data_tensor = shard_tensor_per_spec(data_tensor, rank, input_shard_specs[0], input_device_meshs[0]) + assert "S" not in input_shard_specs[1], "Shape should not be sharded." + + expected = np.reshape(data_tensor, shape_tensor) + local_expected = shard_tensor_per_spec(expected, rank, output_shard_specs[0], output_device_meshs[0]) + + onnx_model = distributed_reshape_instance.to_model_proto( + input_types=[FLOAT[tuple(local_data_tensor.shape)], INT64[tuple(shape_tensor.shape)]], + output_types=[FLOAT[tuple(local_expected.shape)]], + ) + + # Each MPI process owns a sharded model. + sess = ort.InferenceSession( + onnx_model.SerializeToString(), + providers=["CUDAExecutionProvider"], + provider_options=[{"device_id": str(rank)}], + ) + + # Each MPI process executes its sharded model. + # The result is `local` tensor stored on a specific MPI rank + # instead of `logical` tensor. + result = sess.run( + None, + { + "data_tensor": local_data_tensor, + "shape_tensor": shape_tensor, + }, + ) + + # Compare local tensor and the corresponding logical sub-tensor + # obtained by sharding logical tensor following output's sharding spec. + np.testing.assert_allclose(result[0], local_expected, rtol=1e-5, atol=1e-8) + + def test_reshape_two_axis_fusion_shape_2_3_sr_01_shape_6_s_01(self): + # Two axis fusion. + # S[0]R, shape=[2, 3], device_mesh=[0, 1] -> S[0], shape = [6], device_mesh=[0, 1] + self._check_distributed_reshape( + shape=( + 2, + 3, + ), + target_shape=(6,), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("S[0]R", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("S[0]",), + ) + + def test_reshape_two_axis_fusion_shape_2_4_rs_01_shape_8_s_0101(self): + # Two axis fusion. + # RS[0], shape=[2, 4], device_mesh=[0, 1] -> S[0], shape = [8], device_mesh=[0, 1, 0, 1] + self._check_distributed_reshape( + shape=( + 2, + 4, + ), + target_shape=(8,), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RS[0]", "R"), + output_device_meshs=[np.array([0, 1, 0, 1])], + output_shard_specs=("S[0]",), + ) + + def test_reshape_two_axis_fusion_shape_2_3_5_srr_01_shape_2_15_sr_01(self): + # Two axis fusion. + # S[0]RR, shape=[2, 3, 5], device_mesh=[0, 1] -> S[0]R, shape = [2, 15], device_mesh=[0, 1] + self._check_distributed_reshape( + shape=( + 2, + 3, + 5, + ), + target_shape=( + 2, + 15, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("S[0]RR", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("S[0]R",), + ) + + def test_reshape_two_axis_fusion_shape_2_3_5_rsr_01_shape_2_15_sr_01(self): + # Two axis fusion. + # RS[0]R, shape=[2, 4, 5], device_mesh=[0, 1] -> RS[0], shape = [2, 20], device_mesh=[0, 1] + self._check_distributed_reshape( + shape=( + 2, + 4, + 5, + ), + target_shape=( + 2, + 20, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RS[0]R", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RS[0]",), + ) + + def test_reshape_two_axis_fusion_shape_2_3_6_rrs_01_shape_2_18_rs_010101(self): + # Two axis fusion. + # RRS[0], shape=[2, 3, 6], device_mesh=[0, 1] -> RS[0], shape = [2, 18], device_mesh=[0, 1, 0, 1, 0, 1] + self._check_distributed_reshape( + shape=( + 2, + 3, + 6, + ), + target_shape=( + 2, + 18, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RRS[0]", "R"), + output_device_meshs=[np.array([0, 1, 0, 1, 0, 1])], + output_shard_specs=("RS[0]",), + ) + # Two axis fusion. + # RRS[0], shape=[2, 3, 8], device_mesh=[0, 1, 0, 1] -> RS[0], shape = [2, 24], device_mesh=[0, 1, 0, 1] * 3 + + # Two axis fusion. + # RS[0]R, shape=[2, 8, 3], device_mesh=[0, 1, 0, 1] -> RS[0], shape = [2, 24], device_mesh=[0, 1, 0, 1] + + def test_reshape_two_axis_decomposition_shape_6_s_01_shape_2_3_sr_01(self): + # Two axis decomposition + # S[0], shape=[6], device_mesh=[0, 1] -> S[0]R, shape=[2, 3], device_mesh=[0, 1] + self._check_distributed_reshape( + shape=(6,), + target_shape=( + 2, + 3, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("S[0]", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("S[0]R",), + ) + + def test_reshape_two_axis_decomposition_shape_16_s_01_shape_1_16_sr_01(self): + # Two axis decomposition + # S[0], shape=[16], device_mesh=[0, 1] -> RS[0], shape=[1, 16], device_mesh=[0, 1] + self._check_distributed_reshape( + shape=(16,), + target_shape=( + 1, + 16, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("S[0]", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RS[0]",), + ) + + def test_reshape_two_axis_decomposition_shape_16_s_01_shape_2_8_sr_01(self): + # Two axis decomposition + # S[0], shape=[16], device_mesh=[0, 1] -> S[0]R, shape=[2, 8], device_mesh=[0, 1] + self._check_distributed_reshape( + shape=(16,), + target_shape=( + 2, + 8, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("S[0]", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("S[0]R",), + ) + + def test_reshape_two_axis_decomposition_shape_16_s_01_shape_4_4_sr_01(self): + # Two axis decomposition + # S[0], shape=[16], device_mesh=[0, 1] -> S[0]R, shape=[4, 4], device_mesh=[0, 1] + self._check_distributed_reshape( + shape=(16,), + target_shape=( + 4, + 4, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("S[0]", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("S[0]R",), + ) + + def test_reshape_two_axis_decomposition_shape_16_s_01_shape_8_2_sr_01(self): + # Two axis decomposition + # S[0], shape=[16], device_mesh=[0, 1] -> S[0]R, shape=[8, 2], device_mesh=[0, 1] + self._check_distributed_reshape( + shape=(16,), + target_shape=( + 8, + 2, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("S[0]", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("S[0]R",), + ) + + def test_reshape_two_axis_decomposition_shape_16_s_01_shape_16_1_sr_01(self): + # Two axis decomposition + # S[0], shape=[16], device_mesh=[0, 1] -> S[0]R, shape=[16, 1], device_mesh=[0, 1] + self._check_distributed_reshape( + shape=(16,), + target_shape=( + 16, + 1, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("S[0]", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("S[0]R",), + ) + + def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_1_16_sr_0101(self): + # Two axis decomposition + # S[0], shape=[16], device_mesh=[0, 1, 0, 1] -> RS[0], shape=[1, 16], device_mesh=[0, 1, 0, 1] + + self._check_distributed_reshape( + shape=(16,), + target_shape=( + 1, + 16, + ), + input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_shard_specs=("S[0]", "R"), + output_device_meshs=[np.array([0, 1, 0, 1])], + output_shard_specs=("RS[0]",), + ) + + def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_2_8_rs_01(self): + # Two axis decomposition + # repeats=2 8 = repeats * [unique IDs] + # S[0], shape=[16], device_mesh=[0, 1, 0, 1] -> RS[0], shape=[2, 8], device_mesh=[0, 1] + self._check_distributed_reshape( + shape=(16,), + target_shape=( + 2, + 8, + ), + input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_shard_specs=("S[0]", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RS[0]",), + ) + + def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_4_4_sr_0101(self): + # Two axis decomposition + # S[0], shape=[16], device_mesh=[0, 1, 0, 1] -> S[0]R, shape=[4, 4], device_mesh=[0, 1, 0, 1] + self._check_distributed_reshape( + shape=(16,), + target_shape=( + 4, + 4, + ), + input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_shard_specs=("S[0]", "R"), + output_device_meshs=[np.array([0, 1, 0, 1])], + output_shard_specs=("S[0]R",), + ) + + def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_8_2_sr_0101(self): + # Two axis decomposition + # S[0], shape=[16], device_mesh=[0, 1, 0, 1] -> S[0]R, shape=[8, 2], device_mesh=[0, 1, 0, 1] + self._check_distributed_reshape( + shape=(16,), + target_shape=( + 8, + 2, + ), + input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_shard_specs=("S[0]", "R"), + output_device_meshs=[np.array([0, 1, 0, 1])], + output_shard_specs=("S[0]R",), + ) + + def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_16_1_sr_0101(self): + # Two axis decomposition + # S[0], shape=[16], device_mesh=[0, 1, 0, 1] -> S[0]R, shape=[16, 1], device_mesh=[0, 1, 0, 1] + self._check_distributed_reshape( + shape=(16,), + target_shape=( + 16, + 1, + ), + input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_shard_specs=("S[0]", "R"), + output_device_meshs=[np.array([0, 1, 0, 1])], + output_shard_specs=("S[0]R",), + ) + + def test_reshape_two_axis_decomposition_shape_21_4096_s_01_shape_3_7_4096_rrs_01(self): + # Two axis decomposition + # [21, 4096] -> [3, 7, 4096] + # data: (21, 2048), (RS, [0, 1]) + # shape: None, (R, [0, 1]) + # reshaped: None, None + # ----------------------------------- + # new reshaped: None, (RRS, [0, 1]) + self._check_distributed_reshape( + shape=( + 21, + 4096, + ), + target_shape=( + 3, + 7, + 4096, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RS[0]", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RRS[0]",), + ) + + def test_reshape_two_axis_decomposition_shape_3_7_4096_rrs_01_shape_3_7_64_64_rrsr_01(self): + # Two axis decomposition + # [3, 7, 4096] -> [3, 7, 64, 64] + # data: (3, 7, 2048), (RRS, [0, 1]) + # shape: None, (R, [0, 1]) + # reshaped: None, None + # ----------------------------------- + # new reshaped: None, (RRSR, [0, 1]) + + self._check_distributed_reshape( + shape=( + 3, + 7, + 4096, + ), + target_shape=( + 3, + 7, + 64, + 64, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RRS[0]", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RRS[0]R",), + ) + + def test_reshape_two_axis_fusion_shape_3_7_4096_rrr_01_shape_21_4906_rr_01(self): + # Two axis fusion + # [3, 7, 4096] -> [21, 4096] + # data: (3, 7, 4096), (RRR, [0, 1]) + # shape: None, (R, [0, 1]) + # reshaped: None, None + # ----------------------------------- + # new reshaped: None, (RR, [0, 1]) + self._check_distributed_reshape( + shape=( + 3, + 7, + 4096, + ), + target_shape=( + 21, + 4096, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RRR", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RR",), + ) + + def test_reshape_two_axis_fusion_shape_21_4096_rrr_01_shape_3_7_4906_rr_01(self): + # Two axis fusion + # [21, 4096] -> [3, 7, 4096] + # data: (21, 4096), (RR, [0, 1]) + # shape: None, (R, [0, 1]) + # reshaped: None, None + # ----------------------------------- + # new reshaped: None, (RRR, [0, 1]) + self._check_distributed_reshape( + shape=( + 21, + 4096, + ), + target_shape=( + 3, + 7, + 4096, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RR", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RRR",), + ) + + def test_reshape_two_axis_fusion_shape_3_64_7_64_rsrr_01_shape_192_7_64_srr_010101(self): + # Two axis fusion + # [3, 64, 7, 64] -> [192, 7, 64] + # data: (3, 32, 7, 64), (RSRR, [0, 1]) + # shape: None, (R, [0, 1]) + # reshaped: None, None + # ----------------------------------- + # new reshaped: None, (SRR, [0, 1, 0, 1, 0, 1]) + + self._check_distributed_reshape( + shape=( + 3, + 64, + 7, + 64, + ), + target_shape=( + 192, + 7, + 64, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RS[0]RR", "R"), + output_device_meshs=[np.array([0, 1, 0, 1, 0, 1])], + output_shard_specs=("S[0]RR",), + ) + + def test_reshape_two_axis_decomposition_shape_192_7_7_srr_010101_shape_3_64_7_7_rsrr_01(self): + # Two axis decomposition + # [192, 7, 7] -> [3, 64, 7, 7] + # data: (96, 7, 7), (SRR, [0, 1, 0, 1, 0, 1]) + # shape: None, (R, [0, 1]) + # reshaped: None, None + # ----------------------------------- + # new reshaped: None, (RSRR, [0.0, 1.0]) + + self._check_distributed_reshape( + shape=( + 192, + 7, + 7, + ), + target_shape=( + 3, + 64, + 7, + 7, + ), + input_device_meshs=[np.array([0, 1, 0, 1, 0, 1])] * 2, + input_shard_specs=("S[0]RR", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RS[0]RR",), + ) + + def test_reshape_two_axis_fusion_shape_3_64_7_7_rsrr_01_shape_192_7_7_srr_010101(self): + # Two axis fusion + # [3, 64, 7, 7] -> [192, 7, 7] + # data: (3, 32, 7, 7), (RSRR, [0, 1]) + # shape: None, (R, [0, 1]) + # reshaped: None, None + # ----------------------------------- + # new reshaped: None, (SRR, [0, 1, 0, 1, 0, 1]) + + self._check_distributed_reshape( + shape=( + 3, + 64, + 7, + 7, + ), + target_shape=( + 192, + 7, + 7, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RS[0]RR", "R"), + output_device_meshs=[np.array([0, 1, 0, 1, 0, 1])], + output_shard_specs=("S[0]RR",), + ) + + def test_reshape_two_axis_decomposition_shape_192_7_64_srr_010101_shape_3_64_7_64_rsrr_01(self): + # Two axis decomposition + # [192, 7, 64] -> [3, 64, 7, 64] + # data: (96, 7, 64), (SRR, [0, 1, 0, 1, 0, 1]) + # shape: None, (R, [0, 1]) + # reshaped: None, None + # ----------------------------------- + # new reshaped: None, (RSRR, [0.0, 1.0]) + + self._check_distributed_reshape( + shape=( + 192, + 7, + 64, + ), + target_shape=( + 3, + 64, + 7, + 64, + ), + input_device_meshs=[np.array([0, 1, 0, 1, 0, 1])] * 2, + input_shard_specs=("S[0]RR", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RS[0]RR",), + ) + + def test_reshape_two_axis_fusion_shape_3_7_64_64_rrsr_01_shape_3_7_4096_rrs_01(self): + # Two axis fusion + # [3, 7, 64, 64] -> [3, 7, 4096] + # data: (3, 7, 32, 64), (RRSR, [0, 1]) + # shape: None, (R, [0, 1]) + # reshaped: None, None + # ----------------------------------- + # new reshaped: None, (RRS, [0, 1]) + + self._check_distributed_reshape( + shape=( + 3, + 7, + 64, + 64, + ), + target_shape=( + 3, + 7, + 4096, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RRS[0]R", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RRS[0]",), + ) + + def test_reshape_two_axis_fusion_shape_3_7_4096_rrs_01_shape_21_4906_rs_01(self): + # Two axis fusion + # [3, 7, 4096] -> [21, 4096] + # data: (3, 7, 2048), (RRS, [0, 1]) + # shape: None, (R, [0, 1]) + # reshaped: None, None + # ----------------------------------- + # new reshaped: None, (RS, [0, 1]) + self._check_distributed_reshape( + shape=( + 3, + 7, + 4096, + ), + target_shape=( + 21, + 4096, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RRS[0]", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RS[0]",), + ) + + class TestDistributed(unittest.TestCase): def test_matmul_rs_sr_rr(self): # It means 1-D tensor with single element: [2].