diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 043789c36c327..ce0c12804b08a 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -40,6 +40,7 @@ "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_slice.cc" "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_reshape.cc" "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_expand.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_reduce.cc" ) endif() # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index 6ccf063c71290..9bc2bdd208a92 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -109,6 +109,7 @@ if (NOT onnxruntime_USE_NCCL) list(APPEND contrib_ops_excluded_files "collective/distributed_slice.cc") list(APPEND contrib_ops_excluded_files "collective/distributed_reshape.cc") list(APPEND contrib_ops_excluded_files "collective/distributed_expand.cc") + list(APPEND contrib_ops_excluded_files "collective/distributed_reduce.cc") endif() set(provider_excluded_files diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.cc b/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.cc new file mode 100644 index 0000000000000..967f30a304ac2 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.cc @@ -0,0 +1,175 @@ + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Distributed computation. +#include "distributed_reduce.h" +#include "sharding.h" +#include "sharding_spec.h" +#include "nccl_kernels.h" +#include "mpi_include.h" + +// ORT system. +#include "core/providers/cuda/cudnn_common.h" +#include "core/providers/cuda/reduction/reduction_ops.h" + +// std C++. +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +DistributedReduceBase::DistributedReduceBase( + const OpKernelInfo& info, + cudnnReduceTensorOp_t cudnn_reduce_op) : DistributedKernel(info) { + keepdims_ = info.GetAttrOrDefault("keepdims", 1); + cudnn_reduce_op_ = cudnn_reduce_op; +}; + +template +Status DistributedReduceBase::ComputeInternal(OpKernelContext* context) const { + const auto& input_sharding_spec = input_shard_specs_.at(0); + const auto& axes_sharding_spec = input_shard_specs_.at(1); + const auto& output_sharding_spec = output_shard_specs_.at(0); + + ORT_ENFORCE(axes_sharding_spec.HasNoShard(), + "It's not worthy to shard axes tensor. " + "If sharding axes is needed, please submit a feature request."); + + const Tensor* input_tensor = context->Input(0); + const Tensor* axes_tensor = context->Input(1); + ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, "Axes tensor must be an 1-D tensor."); + auto axes_span = axes_tensor->DataAsSpan(); + + // Case 1: empty axes means treating this reduction as an identity. + if (axes_span.empty()) { + ORT_ENFORCE( + input_sharding_spec == output_sharding_spec, + "Input and output sharding specs should be the same. Otherwise, resharding is needed."); + auto* output_tensor = context->Output(0, input_tensor->Shape()); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor->MutableData(), input_tensor->Data(), input_tensor->SizeInBytes(), + cudaMemcpyDeviceToDevice, Stream(context))); + return Status::OK(); + } + + // Case 2: this is a valid reduction. Let's prepare for it. + + bool sharding_on_reduced_axes = false; + for (auto axis_it = axes_span.begin(); input_sharding_spec.HasShard() && axis_it != axes_span.end(); ++axis_it) { + if (*axis_it == input_sharding_spec.GetPartitionAxis()) { + sharding_on_reduced_axes = true; + break; + } + } + + if (sharding_on_reduced_axes) { + // Case 2-1: sharding on reduced axes. + ORT_THROW(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::FAIL, "Not implemented. Resharding is required to make reduced axes replica."); + } else { + // Case 2-2: sharding on passing-through axes or no shard. + ORT_ENFORCE( + input_sharding_spec == output_sharding_spec, + "Input and output sharding specs should be the same. Otherwise, resharding is needed."); + onnxruntime::cuda::PrepareReduceMetadata metadata; + ORT_RETURN_IF_ERROR( + onnxruntime::cuda::PrepareForReduce(input_tensor, keepdims_, axes_span, metadata)); + auto output_tensor = context->Output(0, metadata.squeezed_output_dims); + + // Fast reduction is not deterministic, so sometimes we want to turn it off. + const bool enable_fast_but_non_deterministic_reduction = !context->GetUseDeterministicCompute(); + return onnxruntime::cuda::ReduceComputeCore( + /* GPU allocator */ Info().GetAllocator(OrtMemType::OrtMemTypeDefault), + *input_tensor, metadata, *output_tensor, cudnn_reduce_op_, axes_span, + /* calculate_log */ false, /* calculate_sqt */ false, /* log_sum_exp_ */ false, + enable_fast_but_non_deterministic_reduction, context->GetComputeStream()); + } + return Status::OK(); +} + +template +DistributedReduceSum::DistributedReduceSum( + const OpKernelInfo& info) : DistributedReduceBase(info, CUDNN_REDUCE_TENSOR_ADD){}; + +template +DistributedReduceMean::DistributedReduceMean( + const OpKernelInfo& info) : DistributedReduceBase(info, CUDNN_REDUCE_TENSOR_AVG){}; + +template +DistributedReduceMax::DistributedReduceMax( + const OpKernelInfo& info) : DistributedReduceBase(info, CUDNN_REDUCE_TENSOR_MAX){}; + +// ReduceSum +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceSum, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceSum); +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceSum, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceSum); + +// ReduceMean +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceMean, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceMean); +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceMean, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceMean); + +// ReduceMax +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceMax, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceMax); +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedReduceMax, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedReduceMax); + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.h b/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.h new file mode 100644 index 0000000000000..2939852c75c60 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_reduce.h @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "sharding_spec.h" +#include "sharding.h" +#include "core/providers/cuda/cuda_kernel.h" + +#include +#include +#include +#include +#include +#include + +#pragma once + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +class DistributedReduceBase : public DistributedKernel { + public: + explicit DistributedReduceBase(const OpKernelInfo& info, cudnnReduceTensorOp_t cudnn_reduce_op); + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + // ONNX attribute. If true, reduced axes are retained as dimensions with size one. + // Otherwise, drop reduced axes. + bool keepdims_; + cudnnReduceTensorOp_t cudnn_reduce_op_; +}; + +template +class DistributedReduceSum final : public DistributedReduceBase { + public: + explicit DistributedReduceSum(const OpKernelInfo& info); +}; + +template +class DistributedReduceMean final : public DistributedReduceBase { + public: + explicit DistributedReduceMean(const OpKernelInfo& info); +}; + +template +class DistributedReduceMax final : public DistributedReduceBase { + public: + explicit DistributedReduceMax(const OpKernelInfo& info); +}; + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 2618fe4a238bd..42afb0ac26d46 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -174,6 +174,15 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedExpand); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedExpand); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedExpand); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceSum); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceSum); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceMax); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceMean); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceMean); #endif template <> @@ -352,6 +361,15 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif }; diff --git a/onnxruntime/core/graph/contrib_ops/collective_defs.cc b/onnxruntime/core/graph/contrib_ops/collective_defs.cc index 070df487a264d..8b5b561c1ad87 100644 --- a/onnxruntime/core/graph/contrib_ops/collective_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/collective_defs.cc @@ -273,6 +273,129 @@ void RegisterCollectiveOps() { OpSchema::NonDifferentiable) .Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedReduceSum) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) + .Attr("input_shard_specs", + "The sharding spec of inputs. " + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "Similar to input_shard_specs but for outputs.", + AttributeProto::STRINGS) + .Attr("keepdims", + "Keep the reduced dimension or not, default 1 mean keep reduced dimension.", + AttributeProto::INT, + static_cast(1)) + .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input( + 1, + "shape", + "A 1-D tensor indicates the shape you want to expand to, following the broadcast rule", + "tensor(int64)", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedReduceMax) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) + .Attr("input_shard_specs", + "The sharding spec of inputs. " + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "Similar to input_shard_specs but for outputs.", + AttributeProto::STRINGS) + .Attr("keepdims", + "Keep the reduced dimension or not, default 1 mean keep reduced dimension.", + AttributeProto::INT, + static_cast(1)) + .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input( + 1, + "shape", + "A 1-D tensor indicates the shape you want to expand to, following the broadcast rule", + "tensor(int64)", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedReduceMean) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) + .Attr("input_shard_specs", + "The sharding spec of inputs. " + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "Similar to input_shard_specs but for outputs.", + AttributeProto::STRINGS) + .Attr("keepdims", + "Keep the reduced dimension or not, default 1 mean keep reduced dimension.", + AttributeProto::INT, + static_cast(1)) + .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input( + 1, + "shape", + "A 1-D tensor indicates the shape you want to expand to, following the broadcast rule", + "tensor(int64)", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors."); } } // namespace contrib diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc index d46ed9c245a8e..bc78e577c5052 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc @@ -614,6 +614,30 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, return Status::OK(); } +template Status ReduceComputeCore( + const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, + /*out*/ Tensor& output, cudnnReduceTensorOp_t cudnn_reduce_op, + gsl::span axes, + bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, + Stream* ort_stream, + const TensorShape* input_shape_override); + +template Status ReduceComputeCore( + const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, + /*out*/ Tensor& output, cudnnReduceTensorOp_t cudnn_reduce_op, + gsl::span axes, + bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, + Stream* ort_stream, + const TensorShape* input_shape_override); + +template Status ReduceComputeCore( + const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, + /*out*/ Tensor& output, cudnnReduceTensorOp_t cudnn_reduce_op, + gsl::span axes, + bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, + Stream* ort_stream, + const TensorShape* input_shape_override); + template template Status ReduceKernel::ComputeImpl(OpKernelContext* ctx, cudnnReduceTensorOp_t cudnn_reduce_op) const { diff --git a/onnxruntime/test/python/onnxruntime_test_distributed.py b/onnxruntime/test/python/onnxruntime_test_distributed.py index e0fb3979a9f55..6f691972181b5 100644 --- a/onnxruntime/test/python/onnxruntime_test_distributed.py +++ b/onnxruntime/test/python/onnxruntime_test_distributed.py @@ -7,7 +7,7 @@ import numpy as np import onnxscript from mpi4py import MPI -from onnxscript import FLOAT, INT64 +from onnxscript import FLOAT, FLOAT16, INT64 import onnxruntime as ort @@ -27,12 +27,23 @@ def shard_tensor_per_device_mesh(X, rank, axis, device_mesh): return np.concatenate(selected_shards, axis=axis) -def translate_device_mesh_to_attrs(device_mesh: np.ndarray): +def translate_single_device_mesh(device_mesh: np.ndarray): device_mesh_shape = "[" + ",".join(str(dim) for dim in device_mesh.shape) + "]" device_mesh_elements = "[" + ",".join(str(elem) for elem in device_mesh.flat) + "]" return device_mesh_shape, device_mesh_elements +def translate_all_device_meshes(device_meshes: np.ndarray): + assert all(len(mesh.shape) == 1 for mesh in device_meshes) + device_mesh_shapes = [] + device_mesh_elements = [] + for device_mesh in device_meshes: + device_mesh_shape, device_mesh_element = translate_single_device_mesh(device_mesh) + device_mesh_shapes.append(device_mesh_shape) + device_mesh_elements.append(device_mesh_element) + return device_mesh_shapes, device_mesh_elements + + def parse_sharding_spec(spec: str): axis_conditions = [] sharding_device_axes = [] @@ -90,29 +101,13 @@ def _check_distributed_reshape( self, shape: Tuple[int, ...], target_shape: Tuple[int, ...], - input_device_meshs: np.ndarray, + input_device_meshes: np.ndarray, input_shard_specs: Tuple[str, ...], - output_device_meshs: np.ndarray, + output_device_meshes: np.ndarray, output_shard_specs: Tuple[str, ...], ): - assert all(len(mesh.shape) == 1 for mesh in input_device_meshs) - assert all(len(mesh.shape) == 1 for mesh in output_device_meshs) - assert len(input_device_meshs) == len(input_shard_specs) - assert len(output_device_meshs) == len(output_shard_specs) - - input_device_mesh_shapes = [] - input_device_mesh_elements = [] - for device_mesh in input_device_meshs: - device_mesh_shape, device_mesh_element = translate_device_mesh_to_attrs(device_mesh) - input_device_mesh_shapes.append(device_mesh_shape) - input_device_mesh_elements.append(device_mesh_element) - - output_device_mesh_shapes = [] - output_device_mesh_elements = [] - for device_mesh in output_device_meshs: - device_mesh_shape, device_mesh_element = translate_device_mesh_to_attrs(device_mesh) - output_device_mesh_shapes.append(device_mesh_shape) - output_device_mesh_elements.append(device_mesh_element) + input_device_mesh_shapes, input_device_mesh_elements = translate_all_device_meshes(input_device_meshes) + output_device_mesh_shapes, output_device_mesh_elements = translate_all_device_meshes(output_device_meshes) @onnxscript.script() def distributed_reshape_instance(data_tensor: FLOAT, shape_tensor: INT64): @@ -134,11 +129,11 @@ def distributed_reshape_instance(data_tensor: FLOAT, shape_tensor: INT64): dtype=np.int64, ) - local_data_tensor = shard_tensor_per_spec(data_tensor, rank, input_shard_specs[0], input_device_meshs[0]) + local_data_tensor = shard_tensor_per_spec(data_tensor, rank, input_shard_specs[0], input_device_meshes[0]) assert "S" not in input_shard_specs[1], "Shape should not be sharded." expected = np.reshape(data_tensor, shape_tensor) - local_expected = shard_tensor_per_spec(expected, rank, output_shard_specs[0], output_device_meshs[0]) + local_expected = shard_tensor_per_spec(expected, rank, output_shard_specs[0], output_device_meshes[0]) onnx_model = distributed_reshape_instance.to_model_proto( input_types=[FLOAT[tuple(local_data_tensor.shape)], INT64[tuple(shape_tensor.shape)]], @@ -176,9 +171,9 @@ def test_reshape_two_axis_fusion_shape_2_3_sr_01_shape_6_s_01(self): 3, ), target_shape=(6,), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]R", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]",), ) @@ -191,9 +186,9 @@ def test_reshape_two_axis_fusion_shape_2_4_rs_01_shape_8_s_0101(self): 4, ), target_shape=(8,), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RS[0]", "R"), - output_device_meshs=[np.array([0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1])], output_shard_specs=("S[0]",), ) @@ -210,9 +205,9 @@ def test_reshape_two_axis_fusion_shape_2_3_5_srr_01_shape_2_15_sr_01(self): 2, 15, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]RR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]R",), ) @@ -229,9 +224,9 @@ def test_reshape_two_axis_fusion_shape_2_3_5_rsr_01_shape_2_15_sr_01(self): 2, 20, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RS[0]R", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]",), ) @@ -248,9 +243,9 @@ def test_reshape_two_axis_fusion_shape_2_3_6_rrs_01_shape_2_18_rs_010101(self): 2, 18, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RRS[0]", "R"), - output_device_meshs=[np.array([0, 1, 0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1, 0, 1])], output_shard_specs=("RS[0]",), ) # Two axis fusion. @@ -268,9 +263,9 @@ def test_reshape_two_axis_decomposition_shape_6_s_01_shape_2_3_sr_01(self): 2, 3, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]R",), ) @@ -283,9 +278,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_01_shape_1_16_sr_01(self): 1, 16, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]",), ) @@ -298,9 +293,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_01_shape_2_8_sr_01(self): 2, 8, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]R",), ) @@ -313,9 +308,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_01_shape_4_4_sr_01(self): 4, 4, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]R",), ) @@ -328,9 +323,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_01_shape_8_2_sr_01(self): 8, 2, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]R",), ) @@ -343,9 +338,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_01_shape_16_1_sr_01(self): 16, 1, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]R",), ) @@ -359,9 +354,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_1_16_sr_0101(self) 1, 16, ), - input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_device_meshes=[np.array([0, 1, 0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1])], output_shard_specs=("RS[0]",), ) @@ -375,9 +370,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_2_8_rs_01(self): 2, 8, ), - input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_device_meshes=[np.array([0, 1, 0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]",), ) @@ -390,9 +385,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_4_4_sr_0101(self): 4, 4, ), - input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_device_meshes=[np.array([0, 1, 0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1])], output_shard_specs=("S[0]R",), ) @@ -405,9 +400,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_8_2_sr_0101(self): 8, 2, ), - input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_device_meshes=[np.array([0, 1, 0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1])], output_shard_specs=("S[0]R",), ) @@ -420,9 +415,9 @@ def test_reshape_two_axis_decomposition_shape_16_s_0101_shape_16_1_sr_0101(self) 16, 1, ), - input_device_meshs=[np.array([0, 1, 0, 1])] * 2, + input_device_meshes=[np.array([0, 1, 0, 1])] * 2, input_shard_specs=("S[0]", "R"), - output_device_meshs=[np.array([0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1])], output_shard_specs=("S[0]R",), ) @@ -444,9 +439,9 @@ def test_reshape_two_axis_decomposition_shape_21_4096_s_01_shape_3_7_4096_rrs_01 7, 4096, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RS[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RRS[0]",), ) @@ -471,9 +466,9 @@ def test_reshape_two_axis_decomposition_shape_3_7_4096_rrs_01_shape_3_7_64_64_rr 64, 64, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RRS[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RRS[0]R",), ) @@ -495,9 +490,9 @@ def test_reshape_two_axis_fusion_shape_3_7_4096_rrr_01_shape_21_4906_rr_01(self) 21, 4096, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RRR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RR",), ) @@ -519,9 +514,9 @@ def test_reshape_two_axis_fusion_shape_21_4096_rrr_01_shape_3_7_4906_rr_01(self) 7, 4096, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RRR",), ) @@ -546,9 +541,9 @@ def test_reshape_two_axis_fusion_shape_3_64_7_64_rsrr_01_shape_192_7_64_srr_0101 7, 64, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RS[0]RR", "R"), - output_device_meshs=[np.array([0, 1, 0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1, 0, 1])], output_shard_specs=("S[0]RR",), ) @@ -573,9 +568,9 @@ def test_reshape_two_axis_decomposition_shape_192_7_7_srr_010101_shape_3_64_7_7_ 7, 7, ), - input_device_meshs=[np.array([0, 1, 0, 1, 0, 1])] * 2, + input_device_meshes=[np.array([0, 1, 0, 1, 0, 1])] * 2, input_shard_specs=("S[0]RR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]RR",), ) @@ -600,9 +595,9 @@ def test_reshape_two_axis_fusion_shape_3_64_7_7_rsrr_01_shape_192_7_7_srr_010101 7, 7, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RS[0]RR", "R"), - output_device_meshs=[np.array([0, 1, 0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1, 0, 1])], output_shard_specs=("S[0]RR",), ) @@ -627,9 +622,9 @@ def test_reshape_two_axis_decomposition_shape_192_7_64_srr_010101_shape_3_64_7_6 7, 64, ), - input_device_meshs=[np.array([0, 1, 0, 1, 0, 1])] * 2, + input_device_meshes=[np.array([0, 1, 0, 1, 0, 1])] * 2, input_shard_specs=("S[0]RR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]RR",), ) @@ -654,9 +649,9 @@ def test_reshape_two_axis_fusion_shape_3_7_64_64_rrsr_01_shape_3_7_4096_rrs_01(s 7, 4096, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RRS[0]R", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RRS[0]",), ) @@ -678,9 +673,9 @@ def test_reshape_two_axis_fusion_shape_3_7_4096_rrs_01_shape_21_4906_rs_01(self) 21, 4096, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RRS[0]", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]",), ) @@ -690,29 +685,16 @@ def _check_distributed_expand( self, shape: Tuple[int, ...], target_shape: Tuple[int, ...], - input_device_meshs: np.ndarray, + input_device_meshes: np.ndarray, input_shard_specs: Tuple[str, ...], - output_device_meshs: np.ndarray, + output_device_meshes: np.ndarray, output_shard_specs: Tuple[str, ...], ): - assert all(len(mesh.shape) == 1 for mesh in input_device_meshs) - assert all(len(mesh.shape) == 1 for mesh in output_device_meshs) - assert len(input_device_meshs) == len(input_shard_specs) - assert len(output_device_meshs) == len(output_shard_specs) - - input_device_mesh_shapes = [] - input_device_mesh_elements = [] - for device_mesh in input_device_meshs: - device_mesh_shape, device_mesh_element = translate_device_mesh_to_attrs(device_mesh) - input_device_mesh_shapes.append(device_mesh_shape) - input_device_mesh_elements.append(device_mesh_element) - - output_device_mesh_shapes = [] - output_device_mesh_elements = [] - for device_mesh in output_device_meshs: - device_mesh_shape, device_mesh_element = translate_device_mesh_to_attrs(device_mesh) - output_device_mesh_shapes.append(device_mesh_shape) - output_device_mesh_elements.append(device_mesh_element) + assert len(input_device_meshes) == len(input_shard_specs) + assert len(output_device_meshes) == len(output_shard_specs) + + input_device_mesh_shapes, input_device_mesh_elements = translate_all_device_meshes(input_device_meshes) + output_device_mesh_shapes, output_device_mesh_elements = translate_all_device_meshes(output_device_meshes) @onnxscript.script() def distributed_expand_instance(data_tensor: FLOAT, shape_tensor: INT64): @@ -734,11 +716,11 @@ def distributed_expand_instance(data_tensor: FLOAT, shape_tensor: INT64): dtype=np.int64, ) - local_data_tensor = shard_tensor_per_spec(data_tensor, rank, input_shard_specs[0], input_device_meshs[0]) + local_data_tensor = shard_tensor_per_spec(data_tensor, rank, input_shard_specs[0], input_device_meshes[0]) assert "S" not in input_shard_specs[1], "Shape should not be sharded." expected = data_tensor * np.ones(shape_tensor) - local_expected = shard_tensor_per_spec(expected, rank, output_shard_specs[0], output_device_meshs[0]) + local_expected = shard_tensor_per_spec(expected, rank, output_shard_specs[0], output_device_meshes[0]) onnx_model = distributed_expand_instance.to_model_proto( input_types=[FLOAT[tuple(local_data_tensor.shape)], INT64[tuple(shape_tensor.shape)]], @@ -780,9 +762,9 @@ def test_expand_sharded_on_expanded_axis(self): 8, 4, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]",), ) @@ -799,9 +781,9 @@ def test_expand_sharded_on_expanded_axis_with_device_mesh_0101(self): 8, 8, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RR", "R"), - output_device_meshs=[np.array([0, 1, 0, 1])], + output_device_meshes=[np.array([0, 1, 0, 1])], output_shard_specs=("RS[0]",), ) @@ -818,9 +800,9 @@ def test_expand_replicated_on_expanded_axis(self): 1, 4, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RR",), ) @@ -837,12 +819,12 @@ def test_expand_with_pass_through_sharding_spec(self): 1, 4, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=( "S[0]R", "R", ), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("S[0]R",), ) @@ -863,13 +845,160 @@ def test_expand_in_tiny_llama(self): 256, 4, ), - input_device_meshs=[np.array([0, 1])] * 2, + input_device_meshes=[np.array([0, 1])] * 2, input_shard_specs=("RS[0]RR", "R"), - output_device_meshs=[np.array([0, 1])], + output_device_meshes=[np.array([0, 1])], output_shard_specs=("RS[0]RR",), ) +class TestDistributedReduce(unittest.TestCase): + def _check_distributed_reduce( + self, + keepdims: int, + dtype: np.dtype, + shape: Tuple[int, ...], + axes: Tuple[int, ...], + input_device_meshes: np.ndarray, + input_shard_specs: Tuple[str, ...], + output_device_meshes: np.ndarray, + output_shard_specs: Tuple[str, ...], + ): + assert len(input_device_meshes) == len(input_shard_specs) + assert len(output_device_meshes) == len(output_shard_specs) + + input_device_mesh_shapes, input_device_mesh_elements = translate_all_device_meshes(input_device_meshes) + output_device_mesh_shapes, output_device_mesh_elements = translate_all_device_meshes(output_device_meshes) + + @onnxscript.script() + def distributed_reduce_sum_instance(data_tensor: FLOAT, axes_tensor: INT64): + return MICROSOFT_OPSET.DistributedReduceSum( + data_tensor, + axes_tensor, + keepdims=keepdims, + input_device_mesh_shapes=input_device_mesh_shapes, + input_device_mesh_elements=input_device_mesh_elements, + input_shard_specs=input_shard_specs, + output_device_mesh_shapes=output_device_mesh_shapes, + output_device_mesh_elements=output_device_mesh_elements, + output_shard_specs=output_shard_specs, + ) + + @onnxscript.script() + def distributed_reduce_max_instance(data_tensor: FLOAT, axes_tensor: INT64): + return MICROSOFT_OPSET.DistributedReduceMax( + data_tensor, + axes_tensor, + keepdims=keepdims, + input_device_mesh_shapes=input_device_mesh_shapes, + input_device_mesh_elements=input_device_mesh_elements, + input_shard_specs=input_shard_specs, + output_device_mesh_shapes=output_device_mesh_shapes, + output_device_mesh_elements=output_device_mesh_elements, + output_shard_specs=output_shard_specs, + ) + + @onnxscript.script() + def distributed_reduce_mean_instance(data_tensor: FLOAT, axes_tensor: INT64): + return MICROSOFT_OPSET.DistributedReduceMean( + data_tensor, + axes_tensor, + keepdims=keepdims, + input_device_mesh_shapes=input_device_mesh_shapes, + input_device_mesh_elements=input_device_mesh_elements, + input_shard_specs=input_shard_specs, + output_device_mesh_shapes=output_device_mesh_shapes, + output_device_mesh_elements=output_device_mesh_elements, + output_shard_specs=output_shard_specs, + ) + + rank = comm.Get_rank() + + for onnx_func, np_func in zip( + [distributed_reduce_sum_instance, distributed_reduce_max_instance, distributed_reduce_mean_instance], + [np.sum, np.maximum.reduce, np.mean], + ): + data = np.random.randint(4, size=shape).astype(dtype) + expected = np_func(data, axis=axes, keepdims=bool(keepdims)) + + assert len(input_shard_specs) == 2 and len(input_device_meshes) == 2, "Reduce has two inputs." + assert "S" not in input_shard_specs[1], "Tensor `axes` should not be sharded." + assert len(output_shard_specs) == 1 and len(output_device_meshes) == 1, "Reduce has only one output." + + local_data = shard_tensor_per_spec(data, rank, input_shard_specs[0], input_device_meshes[0]) + local_expected = shard_tensor_per_spec(expected, rank, output_shard_specs[0], output_device_meshes[0]) + + if dtype == np.float32: + onnx_model = onnx_func.to_model_proto( + input_types=[FLOAT[tuple(local_data.shape)], INT64[len(axes)]], + output_types=[FLOAT[tuple(local_expected.shape)]], + ) + elif dtype == np.int64: + onnx_model = onnx_func.to_model_proto( + input_types=[INT64[tuple(local_data.shape)], INT64[len(axes)]], + output_types=[INT64[tuple(local_expected.shape)]], + ) + elif dtype == np.float16: + onnx_model = onnx_func.to_model_proto( + input_types=[FLOAT16[tuple(local_data.shape)], INT64[len(axes)]], + output_types=[FLOAT16[tuple(local_expected.shape)]], + ) + else: + raise RuntimeError(f"Unsupported dtype: {dtype}") + + # Each MPI process owns a sharded model. + sess = ort.InferenceSession( + onnx_model.SerializeToString(), + providers=["CUDAExecutionProvider"], + provider_options=[{"device_id": str(rank)}], + ) + + # Each MPI process executes its sharded model. + # The result is `local` tensor stored on a specific MPI rank + # instead of `logical` tensor. + result = sess.run( + None, + { + "data_tensor": local_data, + "axes_tensor": np.array(axes, dtype=np.int64), + }, + ) + + # Compare local tensor and the corresponding logical sub-tensor + # obtained by sharding logical tensor following output's sharding spec. + np.testing.assert_allclose(result[0], local_expected, rtol=1e-5, atol=1e-8) + + def test_reduce(self): + self._check_distributed_reduce( + keepdims=1, + dtype=np.float32, + shape=( + 8, + 4, + ), + axes=(0,), + input_device_meshes=[np.array([0, 1])] * 2, + input_shard_specs=("RR", "R"), + output_device_meshes=[np.array([0, 1])], + output_shard_specs=("RR",), + ) + + def test_reduce_sharded(self): + self._check_distributed_reduce( + keepdims=1, + dtype=np.float32, + shape=( + 8, + 4, + ), + axes=(1,), + input_device_meshes=[np.array([0, 1])] * 2, + input_shard_specs=("S[0]R", "R"), + output_device_meshes=[np.array([0, 1])], + output_shard_specs=("S[0]R",), + ) + + class TestDistributed(unittest.TestCase): def test_matmul_rs_sr_rr(self): # It means 1-D tensor with single element: [2].