From 9e8ad398479d9c2dc0ca91a8df89e452d059f6ee Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Wed, 1 Nov 2023 08:49:33 -0700 Subject: [PATCH] Distributed Reduction (#18206) This PR implements distributed reduciton for llama 2. This version doesn't consider any cases requring re-sharding because we haven't seen any use cases. Intutive examples: - [supported] [2,4,6]-tensor with spec=RRS[0] and device_mesh=[0,1] -> Reduce(axes=[0]) -> [1,4,6]-tensor with spec=RRS[0] and device_mesh=[0,1] - [supported] [2,4,6]-tensor with spec=RRS[0] and device_mesh=[0,1] -> Reduce(axes=[1]) -> [2,1,6]-tensor with spec=RRS[0] and device_mesh=[0,1] - [not supported] [2,4,6]-tensor with spec=RRS[0] and device_mesh=[0,1] -> Reduce(axes=[2]) -> [2,4,1]-tensor with spec=RRS[0] and device_mesh=[0,1] Algorithm: When the reduced axes are not sharded, each device can call reduction directly. The output sharding spec will be identical to input sharding spec. We currently throw when input and output sharding specs are different. Review guideline: - Check 97b8d2f for new op's schema and how new op is registered. - Read tests in 2450f93 to get faimilar with the behavior of these ops. - Check the implementation details in 753d9af. --- cmake/onnxruntime_providers_cuda.cmake | 1 + cmake/onnxruntime_rocm_hipify.cmake | 1 + .../cuda/collective/distributed_reduce.cc | 175 +++++++++ .../cuda/collective/distributed_reduce.h | 59 +++ .../contrib_ops/cuda/cuda_contrib_kernels.cc | 18 + .../core/graph/contrib_ops/collective_defs.cc | 123 +++++++ .../providers/cuda/reduction/reduction_ops.cc | 24 ++ .../python/onnxruntime_test_distributed.py | 345 ++++++++++++------ 8 files changed, 638 insertions(+), 108 deletions(-) create mode 100644 onnxruntime/contrib_ops/cuda/collective/distributed_reduce.cc create mode 100644 onnxruntime/contrib_ops/cuda/collective/distributed_reduce.h diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 043789c36c327..ce0c12804b08a 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -40,6 +40,7 @@ "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_slice.cc" "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_reshape.cc" "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_expand.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_reduce.cc" ) endif() # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index 6ccf063c71290..9bc2bdd208a92 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -109,6 +109,7 @@ if (NOT onnxruntime_USE_NCCL) list(APPEND contrib_ops_excluded_files "collective/distributed_slice.cc") list(APPEND contrib_ops_excluded_files "collective/distributed_reshape.cc") list(APPEND contrib_ops_excluded_files "collective/distributed_expand.cc") + list(APPEND contrib_ops_excluded_files "collective/distributed_reduce.cc") endif() set(provider_excluded_files diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.cc b/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.cc new file mode 100644 index 0000000000000..967f30a304ac2 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.cc @@ -0,0 +1,175 @@ + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Distributed computation. +#include "distributed_reduce.h" +#include "sharding.h" +#include "sharding_spec.h" +#include "nccl_kernels.h" +#include "mpi_include.h" + +// ORT system. +#include "core/providers/cuda/cudnn_common.h" +#include "core/providers/cuda/reduction/reduction_ops.h" + +// std C++. +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +DistributedReduceBase::DistributedReduceBase( + const OpKernelInfo& info, + cudnnReduceTensorOp_t cudnn_reduce_op) : DistributedKernel(info) { + keepdims_ = info.GetAttrOrDefault("keepdims", 1); + cudnn_reduce_op_ = cudnn_reduce_op; +}; + +template +Status DistributedReduceBase::ComputeInternal(OpKernelContext* context) const { + const auto& input_sharding_spec = input_shard_specs_.at(0); + const auto& axes_sharding_spec = input_shard_specs_.at(1); + const auto& output_sharding_spec = output_shard_specs_.at(0); + + ORT_ENFORCE(axes_sharding_spec.HasNoShard(), + "It's not worthy to shard axes tensor. " + "If sharding axes is needed, please submit a feature request."); + + const Tensor* input_tensor = context->Input(0); + const Tensor* axes_tensor = context->Input(1); + ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, "Axes tensor must be an 1-D tensor."); + auto axes_span = axes_tensor->DataAsSpan(); + + // Case 1: empty axes means treating this reduction as an identity. + if (axes_span.empty()) { + ORT_ENFORCE( + input_sharding_spec == output_sharding_spec, + "Input and output sharding specs should be the same. Otherwise, resharding is needed."); + auto* output_tensor = context->Output(0, input_tensor->Shape()); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor->MutableData(), input_tensor->Data(), input_tensor->SizeInBytes(), + cudaMemcpyDeviceToDevice, Stream(context))); + return Status::OK(); + } + + // Case 2: this is a valid reduction. Let's prepare for it. + + bool sharding_on_reduced_axes = false; + for (auto axis_it = axes_span.begin(); input_sharding_spec.HasShard() && axis_it != axes_span.end(); ++axis_it) { + if (*axis_it == input_sharding_spec.GetPartitionAxis()) { + sharding_on_reduced_axes = true; + break; + } + } + + if (sharding_on_reduced_axes) { + // Case 2-1: sharding on reduced axes. + ORT_THROW(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::FAIL, "Not implemented. Resharding is required to make reduced axes replica."); + } else { + // Case 2-2: sharding on passing-through axes or no shard. + ORT_ENFORCE( + input_sharding_spec == output_sharding_spec, + "Input and output sharding specs should be the same. Otherwise, resharding is needed."); + onnxruntime::cuda::PrepareReduceMetadata metadata; + ORT_RETURN_IF_ERROR( + onnxruntime::cuda::PrepareForReduce(input_tensor, keepdims_, axes_span, metadata)); + auto output_tensor = context->Output(0, metadata.squeezed_output_dims); + + // Fast reduction is not deterministic, so sometimes we want to turn it off. + const bool enable_fast_but_non_deterministic_reduction = !context->GetUseDeterministicCompute(); + return onnxruntime::cuda::ReduceComputeCore( + /* GPU allocator */ Info().GetAllocator(OrtMemType::OrtMemTypeDefault), + *input_tensor, metadata, *output_tensor, cudnn_reduce_op_, axes_span, + /* calculate_log */ false, /* calculate_sqt */ false, /* log_sum_exp_ */ false, + enable_fast_but_non_deterministic_reduction, context->GetComputeStream()); + } + return Status::OK(); +} + +template +DistributedReduceSum::DistributedReduceSum( + const OpKernelInfo& info) : DistributedReduceBase(info, CUDNN_REDUCE_TENSOR_ADD){}; + +template +DistributedReduceMean::DistributedReduceMean( + const OpKernelInfo& info) : DistributedReduceBase(info, CUDNN_REDUCE_TENSOR_AVG){}; + +template +DistributedReduceMax::DistributedReduceMax( + const OpKernelInfo& info) : DistributedReduceBase(info, CUDNN_REDUCE_TENSOR_MAX){}; + +// ReduceSum +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceSum, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceSum); +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceSum, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceSum); + +// ReduceMean +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceMean, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceMean); +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceMean, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceMean); + +// ReduceMax +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceMax, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceMax); +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceMax, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceMax); + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.h b/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.h new file mode 100644 index 0000000000000..2939852c75c60 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.h @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "sharding_spec.h" +#include "sharding.h" +#include "core/providers/cuda/cuda_kernel.h" + +#include +#include +#include +#include +#include +#include + +#pragma once + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +class DistributedReduceBase : public DistributedKernel { + public: + explicit DistributedReduceBase(const OpKernelInfo& info, cudnnReduceTensorOp_t cudnn_reduce_op); + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + // ONNX attribute. If true, reduced axes are retained as dimensions with size one. + // Otherwise, drop reduced axes. + bool keepdims_; + cudnnReduceTensorOp_t cudnn_reduce_op_; +}; + +template +class DistributedReduceSum final : public DistributedReduceBase { + public: + explicit DistributedReduceSum(const OpKernelInfo& info); +}; + +template +class DistributedReduceMean final : public DistributedReduceBase { + public: + explicit DistributedReduceMean(const OpKernelInfo& info); +}; + +template +class DistributedReduceMax final : public DistributedReduceBase { + public: + explicit DistributedReduceMax(const OpKernelInfo& info); +}; + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index d51915b85095f..8e157da6cb43f 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -175,6 +175,15 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedExpand); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedExpand); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedExpand); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceSum); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceSum); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceMax); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceMean); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceMean); #endif template <> @@ -354,6 +363,15 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif }; diff --git a/onnxruntime/core/graph/contrib_ops/collective_defs.cc b/onnxruntime/core/graph/contrib_ops/collective_defs.cc index 070df487a264d..8b5b561c1ad87 100644 --- a/onnxruntime/core/graph/contrib_ops/collective_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/collective_defs.cc @@ -273,6 +273,129 @@ void RegisterCollectiveOps() { OpSchema::NonDifferentiable) .Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedReduceSum) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) + .Attr("input_shard_specs", + "The sharding spec of inputs. " + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "Similar to input_shard_specs but for outputs.", + AttributeProto::STRINGS) + .Attr("keepdims", + "Keep the reduced dimension or not, default 1 mean keep reduced dimension.", + AttributeProto::INT, + static_cast(1)) + .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input( + 1, + "shape", + "A 1-D tensor indicates the shape you want to expand to, following the broadcast rule", + "tensor(int64)", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedReduceMax) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) + .Attr("input_shard_specs", + "The sharding spec of inputs. " + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "Similar to input_shard_specs but for outputs.", + AttributeProto::STRINGS) + .Attr("keepdims", + "Keep the reduced dimension or not, default 1 mean keep reduced dimension.", + AttributeProto::INT, + static_cast(1)) + .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input( + 1, + "shape", + "A 1-D tensor indicates the shape you want to expand to, following the broadcast rule", + "tensor(int64)", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedReduceMean) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) + .Attr("input_shard_specs", + "The sharding spec of inputs. " + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "Similar to input_shard_specs but for outputs.", + AttributeProto::STRINGS) + .Attr("keepdims", + "Keep the reduced dimension or not, default 1 mean keep reduced dimension.", + AttributeProto::INT, + static_cast(1)) + .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input( + 1, + "shape", + "A 1-D tensor indicates the shape you want to expand to, following the broadcast rule", + "tensor(int64)", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors."); } } // namespace contrib diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc index d46ed9c245a8e..bc78e577c5052 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc @@ -614,6 +614,30 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, return Status::OK(); } +template Status ReduceComputeCore( + const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, + /*out*/ Tensor& output, cudnnReduceTensorOp_t cudnn_reduce_op, + gsl::span axes, + bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, + Stream* ort_stream, + const TensorShape* input_shape_override); + +template Status ReduceComputeCore( + const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, + /*out*/ Tensor& output, cudnnReduceTensorOp_t cudnn_reduce_op, + gsl::span axes, + bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, + Stream* ort_stream, + const TensorShape* input_shape_override); + +template Status ReduceComputeCore( + const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, + /*out*/ Tensor& output, cudnnReduceTensorOp_t cudnn_reduce_op, + gsl::span axes, + bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, + Stream* ort_stream, + const TensorShape* input_shape_override); + template template Status ReduceKernel::ComputeImpl(OpKernelContext* ctx, cudnnReduceTensorOp_t cudnn_reduce_op) const { diff --git a/onnxruntime/test/python/onnxruntime_test_distributed.py b/onnxruntime/test/python/onnxruntime_test_distributed.py index e0fb3979a9f55..6f691972181b5 100644 --- a/onnxruntime/test/python/onnxruntime_test_distributed.py +++ b/onnxruntime/test/python/onnxruntime_test_distributed.py @@ -7,7 +7,7 @@ import numpy as np import onnxscript from mpi4py import MPI -from onnxscript import FLOAT, INT64 +from onnxscript import FLOAT, FLOAT16, INT64 import onnxruntime as ort @@ -27,12 +27,23 @@ def shard_tensor_per_device_mesh(X, rank, axis, device_mesh): return np.concatenate(selected_shards, axis=axis) -def translate_device_mesh_to_attrs(device_mesh: np.ndarray): +def translate_single_device_mesh(device_mesh: np.ndarray): device_mesh_shape = "[" + ",".join(str(dim) for dim in device_mesh.shape) + "]" device_mesh_elements = "[" + ",".join(str(elem) for elem in device_mesh.flat) + "]" return device_mesh_shape, device_mesh_elements +def translate_all_device_meshes(device_meshes: np.ndarray): + assert all(len(mesh.shape) == 1 for mesh in device_meshes) + device_mesh_shapes = [] + device_mesh_elements = [] + for device_mesh in device_meshes: + device_mesh_shape, device_mesh_element = translate_single_device_mesh(device_mesh) + device_mesh_shapes.append(device_mesh_shape) + device_mesh_elements.append(device_mesh_element) + return device_mesh_shapes, device_mesh_elements + + def parse_sharding_spec(spec: str): axis_conditions = [] sharding_device_axes = [] @@ -90,29 +101,13 @@ def _check_distributed_reshape( self, shape: Tuple[int, ...], target_shape: Tuple[int, ...], - input_device_meshs: np.ndarray, + input_device_meshes: np.ndarray, input_shard_specs: Tuple[str, ...], - output_device_meshs: np.ndarray, + output_device_meshes: np.ndarray, output_shard_specs: Tuple[str, ...], ): - assert all(len(mesh.shape) == 1 for mesh in input_device_meshs) - assert all(len(mesh.shape) == 1 for mesh in output_device_meshs) - assert len(input_device_meshs) == len(input_shard_specs) - assert len(output_device_meshs) == len(output_shard_specs) - - input_device_mesh_shapes = [] - input_device_mesh_elements = [] - for device_mesh in input_device_meshs: - device_mesh_shape, device_mesh_element = translate_device_mesh_to_attrs(device_mesh) - input_device_mesh_shapes.append(device_mesh_shape) - input_device_mesh_elements.append(device_mesh_element) - - output_device_mesh_shapes = [] - output_device_mesh_elements = [] - for device_mesh in output_device_meshs: - device_mesh_shape, device_mesh_element = translate_device_mesh_to_attrs(device_mesh) - output_device_mesh_shapes.append(device_mesh_shape) - output_device_mesh_elements.append(device_mesh_element) + input_device_mesh_shapes, input_device_mesh_elements = translate_all_device_meshes(input_device_meshes) + output_device_mesh_shapes, output_device_mesh_elements = translate_all_device_meshes(output_device_meshes) @onnxscript.script() def distributed_reshape_instance(data_tensor: FLOAT, shape_tensor: INT64): @@ -134,11 +129,11 @@ def distributed_reshape_instance(data_tensor: FLOAT, shape_tensor: INT64): dtype=np.int64, ) - local_data_tensor = shard_tensor_per_spec(data_tensor, rank, input_shard_specs[0], input_device_meshs[0]) + local_data_tensor = shard_tensor_per_spec(data_tensor, rank, input_shard_specs[0], input_device_meshes[0]) assert "S" not in input_shard_specs[1], "Shape should not be sharded." expected = np.reshape(data_tensor, shape_tensor) - local_expected = shard_tensor_per_spec(expected, rank, output_shard_specs[0], output_device_meshs[0]) + local_expected = shard_tensor_per_spec(expected, rank, output_shard_specs[0], output_device_meshes[0]) onnx_model = distributed_reshape_instance.to_model_proto( input_types=[FLOAT[tuple(local_data_tensor.shape)], INT64[tuple(shape_tensor.shape)]], @@ -176,9 +171,9 @@ def test_reshape_two_axis_fusion_shape_2_3_sr_01_shape_6_s_01(self): 3, ), target_shape=(6,), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]R", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]",), ) @@ -191,9 +186,9 @@ def test_reshape_two_axis_fusion_shape_2_4_rs_01_shape_8_s_0101(self): 4, ), target_shape=(8,), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RS[0]", "R"), - output_device_meshs=[np.array([0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1])], output_shard_specs=("S[0]",), ) @@ -210,9 +205,9 @@ def test_reshape_two_axis_fusion_shape_2_3_5_srr_01_shape_2_15_sr_01(self): 2, 15, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]RR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]R",), ) @@ -229,9 +224,9 @@ def test_reshape_two_axis_fusion_shape_2_3_5_rsr_01_shape_2_15_sr_01(self): 2, 20, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RS[0]R", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]",), ) @@ -248,9 +243,9 @@ def test_reshape_two_axis_fusion_shape_2_3_6_rrs_01_shape_2_18_rs_010101(self): 2, 18, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RRS[0]", "R"), - output_device_meshs=[np.array([0, 1, 0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1, 0, 1])], output_shard_specs=("RS[0]",), ) # Two axis fusion. @@ -268,9 +263,9 @@ def test_reshape_two_axis_decomposition_shape_6_s_01_shape_2_3_sr_01(self): 2, 3, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]R",), ) @@ -283,9 +278,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_01_shape_1_16_sr_01(self): 1, 16, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]",), ) @@ -298,9 +293,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_01_shape_2_8_sr_01(self): 2, 8, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]R",), ) @@ -313,9 +308,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_01_shape_4_4_sr_01(self): 4, 4, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]R",), ) @@ -328,9 +323,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_01_shape_8_2_sr_01(self): 8, 2, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]R",), ) @@ -343,9 +338,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_01_shape_16_1_sr_01(self): 16, 1, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]R",), ) @@ -359,9 +354,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_1_16_sr_0101(self) 1, 16, ), - input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_device_meshes=[np.array([0, 1, 0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1])], output_shard_specs=("RS[0]",), ) @@ -375,9 +370,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_2_8_rs_01(self): 2, 8, ), - input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_device_meshes=[np.array([0, 1, 0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]",), ) @@ -390,9 +385,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_4_4_sr_0101(self): 4, 4, ), - input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_device_meshes=[np.array([0, 1, 0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1])], output_shard_specs=("S[0]R",), ) @@ -405,9 +400,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_8_2_sr_0101(self): 8, 2, ), - input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_device_meshes=[np.array([0, 1, 0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1])], output_shard_specs=("S[0]R",), ) @@ -420,9 +415,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_16_1_sr_0101(self) 16, 1, ), - input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_device_meshes=[np.array([0, 1, 0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1])], output_shard_specs=("S[0]R",), ) @@ -444,9 +439,9 @@ def test_reshape_two_axis_decomposition_shape_21_4096_s_01_shape_3_7_4096_rrs_01 7, 4096, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RS[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RRS[0]",), ) @@ -471,9 +466,9 @@ def test_reshape_two_axis_decomposition_shape_3_7_4096_rrs_01_shape_3_7_64_64_rr 64, 64, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RRS[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RRS[0]R",), ) @@ -495,9 +490,9 @@ def test_reshape_two_axis_fusion_shape_3_7_4096_rrr_01_shape_21_4906_rr_01(self) 21, 4096, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RRR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RR",), ) @@ -519,9 +514,9 @@ def test_reshape_two_axis_fusion_shape_21_4096_rrr_01_shape_3_7_4906_rr_01(self) 7, 4096, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RRR",), ) @@ -546,9 +541,9 @@ def test_reshape_two_axis_fusion_shape_3_64_7_64_rsrr_01_shape_192_7_64_srr_0101 7, 64, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RS[0]RR", "R"), - output_device_meshs=[np.array([0, 1, 0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1, 0, 1])], output_shard_specs=("S[0]RR",), ) @@ -573,9 +568,9 @@ def test_reshape_two_axis_decomposition_shape_192_7_7_srr_010101_shape_3_64_7_7_ 7, 7, ), - input_device_meshs=[np.array([0, 1, 0, 1, 0, 1])] * 2, + input_device_meshes=[np.array([0, 1, 0, 1, 0, 1])] * 2, input_shard_specs=("S[0]RR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]RR",), ) @@ -600,9 +595,9 @@ def test_reshape_two_axis_fusion_shape_3_64_7_7_rsrr_01_shape_192_7_7_srr_010101 7, 7, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RS[0]RR", "R"), - output_device_meshs=[np.array([0, 1, 0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1, 0, 1])], output_shard_specs=("S[0]RR",), ) @@ -627,9 +622,9 @@ def test_reshape_two_axis_decomposition_shape_192_7_64_srr_010101_shape_3_64_7_6 7, 64, ), - input_device_meshs=[np.array([0, 1, 0, 1, 0, 1])] * 2, + input_device_meshes=[np.array([0, 1, 0, 1, 0, 1])] * 2, input_shard_specs=("S[0]RR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]RR",), ) @@ -654,9 +649,9 @@ def test_reshape_two_axis_fusion_shape_3_7_64_64_rrsr_01_shape_3_7_4096_rrs_01(s 7, 4096, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RRS[0]R", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RRS[0]",), ) @@ -678,9 +673,9 @@ def test_reshape_two_axis_fusion_shape_3_7_4096_rrs_01_shape_21_4906_rs_01(self) 21, 4096, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RRS[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]",), ) @@ -690,29 +685,16 @@ def _check_distributed_expand( self, shape: Tuple[int, ...], target_shape: Tuple[int, ...], - input_device_meshs: np.ndarray, + input_device_meshes: np.ndarray, input_shard_specs: Tuple[str, ...], - output_device_meshs: np.ndarray, + output_device_meshes: np.ndarray, output_shard_specs: Tuple[str, ...], ): - assert all(len(mesh.shape) == 1 for mesh in input_device_meshs) - assert all(len(mesh.shape) == 1 for mesh in output_device_meshs) - assert len(input_device_meshs) == len(input_shard_specs) - assert len(output_device_meshs) == len(output_shard_specs) - - input_device_mesh_shapes = [] - input_device_mesh_elements = [] - for device_mesh in input_device_meshs: - device_mesh_shape, device_mesh_element = translate_device_mesh_to_attrs(device_mesh) - input_device_mesh_shapes.append(device_mesh_shape) - input_device_mesh_elements.append(device_mesh_element) - - output_device_mesh_shapes = [] - output_device_mesh_elements = [] - for device_mesh in output_device_meshs: - device_mesh_shape, device_mesh_element = translate_device_mesh_to_attrs(device_mesh) - output_device_mesh_shapes.append(device_mesh_shape) - output_device_mesh_elements.append(device_mesh_element) + assert len(input_device_meshes) == len(input_shard_specs) + assert len(output_device_meshes) == len(output_shard_specs) + + input_device_mesh_shapes, input_device_mesh_elements = translate_all_device_meshes(input_device_meshes) + output_device_mesh_shapes, output_device_mesh_elements = translate_all_device_meshes(output_device_meshes) @onnxscript.script() def distributed_expand_instance(data_tensor: FLOAT, shape_tensor: INT64): @@ -734,11 +716,11 @@ def distributed_expand_instance(data_tensor: FLOAT, shape_tensor: INT64): dtype=np.int64, ) - local_data_tensor = shard_tensor_per_spec(data_tensor, rank, input_shard_specs[0], input_device_meshs[0]) + local_data_tensor = shard_tensor_per_spec(data_tensor, rank, input_shard_specs[0], input_device_meshes[0]) assert "S" not in input_shard_specs[1], "Shape should not be sharded." expected = data_tensor * np.ones(shape_tensor) - local_expected = shard_tensor_per_spec(expected, rank, output_shard_specs[0], output_device_meshs[0]) + local_expected = shard_tensor_per_spec(expected, rank, output_shard_specs[0], output_device_meshes[0]) onnx_model = distributed_expand_instance.to_model_proto( input_types=[FLOAT[tuple(local_data_tensor.shape)], INT64[tuple(shape_tensor.shape)]], @@ -780,9 +762,9 @@ def test_expand_sharded_on_expanded_axis(self): 8, 4, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]",), ) @@ -799,9 +781,9 @@ def test_expand_sharded_on_expanded_axis_with_device_mesh_0101(self): 8, 8, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RR", "R"), - output_device_meshs=[np.array([0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1])], output_shard_specs=("RS[0]",), ) @@ -818,9 +800,9 @@ def test_expand_replicated_on_expanded_axis(self): 1, 4, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RR",), ) @@ -837,12 +819,12 @@ def test_expand_with_pass_through_sharding_spec(self): 1, 4, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=( "S[0]R", "R", ), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]R",), ) @@ -863,13 +845,160 @@ def test_expand_in_tiny_llama(self): 256, 4, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RS[0]RR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]RR",), ) +class TestDistributedReduce(unittest.TestCase): + def _check_distributed_reduce( + self, + keepdims: int, + dtype: np.dtype, + shape: Tuple[int, ...], + axes: Tuple[int, ...], + input_device_meshes: np.ndarray, + input_shard_specs: Tuple[str, ...], + output_device_meshes: np.ndarray, + output_shard_specs: Tuple[str, ...], + ): + assert len(input_device_meshes) == len(input_shard_specs) + assert len(output_device_meshes) == len(output_shard_specs) + + input_device_mesh_shapes, input_device_mesh_elements = translate_all_device_meshes(input_device_meshes) + output_device_mesh_shapes, output_device_mesh_elements = translate_all_device_meshes(output_device_meshes) + + @onnxscript.script() + def distributed_reduce_sum_instance(data_tensor: FLOAT, axes_tensor: INT64): + return MICROSOFT_OPSET.DistributedReduceSum( + data_tensor, + axes_tensor, + keepdims=keepdims, + input_device_mesh_shapes=input_device_mesh_shapes, + input_device_mesh_elements=input_device_mesh_elements, + input_shard_specs=input_shard_specs, + output_device_mesh_shapes=output_device_mesh_shapes, + output_device_mesh_elements=output_device_mesh_elements, + output_shard_specs=output_shard_specs, + ) + + @onnxscript.script() + def distributed_reduce_max_instance(data_tensor: FLOAT, axes_tensor: INT64): + return MICROSOFT_OPSET.DistributedReduceMax( + data_tensor, + axes_tensor, + keepdims=keepdims, + input_device_mesh_shapes=input_device_mesh_shapes, + input_device_mesh_elements=input_device_mesh_elements, + input_shard_specs=input_shard_specs, + output_device_mesh_shapes=output_device_mesh_shapes, + output_device_mesh_elements=output_device_mesh_elements, + output_shard_specs=output_shard_specs, + ) + + @onnxscript.script() + def distributed_reduce_mean_instance(data_tensor: FLOAT, axes_tensor: INT64): + return MICROSOFT_OPSET.DistributedReduceMean( + data_tensor, + axes_tensor, + keepdims=keepdims, + input_device_mesh_shapes=input_device_mesh_shapes, + input_device_mesh_elements=input_device_mesh_elements, + input_shard_specs=input_shard_specs, + output_device_mesh_shapes=output_device_mesh_shapes, + output_device_mesh_elements=output_device_mesh_elements, + output_shard_specs=output_shard_specs, + ) + + rank = comm.Get_rank() + + for onnx_func, np_func in zip( + [distributed_reduce_sum_instance, distributed_reduce_max_instance, distributed_reduce_mean_instance], + [np.sum, np.maximum.reduce, np.mean], + ): + data = np.random.randint(4, size=shape).astype(dtype) + expected = np_func(data, axis=axes, keepdims=bool(keepdims)) + + assert len(input_shard_specs) == 2 and len(input_device_meshes) == 2, "Reduce has two inputs." + assert "S" not in input_shard_specs[1], "Tensor `axes` should not be sharded." + assert len(output_shard_specs) == 1 and len(output_device_meshes) == 1, "Reduce has only one output." + + local_data = shard_tensor_per_spec(data, rank, input_shard_specs[0], input_device_meshes[0]) + local_expected = shard_tensor_per_spec(expected, rank, output_shard_specs[0], output_device_meshes[0]) + + if dtype == np.float32: + onnx_model = onnx_func.to_model_proto( + input_types=[FLOAT[tuple(local_data.shape)], INT64[len(axes)]], + output_types=[FLOAT[tuple(local_expected.shape)]], + ) + elif dtype == np.int64: + onnx_model = onnx_func.to_model_proto( + input_types=[INT64[tuple(local_data.shape)], INT64[len(axes)]], + output_types=[INT64[tuple(local_expected.shape)]], + ) + elif dtype == np.float16: + onnx_model = onnx_func.to_model_proto( + input_types=[FLOAT16[tuple(local_data.shape)], INT64[len(axes)]], + output_types=[FLOAT16[tuple(local_expected.shape)]], + ) + else: + raise RuntimeError(f"Unsupported dtype: {dtype}") + + # Each MPI process owns a sharded model. + sess = ort.InferenceSession( + onnx_model.SerializeToString(), + providers=["CUDAExecutionProvider"], + provider_options=[{"device_id": str(rank)}], + ) + + # Each MPI process executes its sharded model. + # The result is `local` tensor stored on a specific MPI rank + # instead of `logical` tensor. + result = sess.run( + None, + { + "data_tensor": local_data, + "axes_tensor": np.array(axes, dtype=np.int64), + }, + ) + + # Compare local tensor and the corresponding logical sub-tensor + # obtained by sharding logical tensor following output's sharding spec. + np.testing.assert_allclose(result[0], local_expected, rtol=1e-5, atol=1e-8) + + def test_reduce(self): + self._check_distributed_reduce( + keepdims=1, + dtype=np.float32, + shape=( + 8, + 4, + ), + axes=(0,), + input_device_meshes=[np.array([0, 1])] * 2, + input_shard_specs=("RR", "R"), + output_device_meshes=[np.array([0, 1])], + output_shard_specs=("RR",), + ) + + def test_reduce_sharded(self): + self._check_distributed_reduce( + keepdims=1, + dtype=np.float32, + shape=( + 8, + 4, + ), + axes=(1,), + input_device_meshes=[np.array([0, 1])] * 2, + input_shard_specs=("S[0]R", "R"), + output_device_meshes=[np.array([0, 1])], + output_shard_specs=("S[0]R",), + ) + + class TestDistributed(unittest.TestCase): def test_matmul_rs_sr_rr(self): # It means 1-D tensor with single element: [2].