From e4c49987f5a09e19527248adcc197b7d4a695636 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Thu, 26 Oct 2023 10:39:05 -0700 Subject: [PATCH 1/5] Expose Expand --- .../core/providers/cuda/tensor/expand.cc | 80 +++++++++++++++++++ .../core/providers/cuda/tensor/expand.h | 13 +++ 2 files changed, 93 insertions(+) diff --git a/onnxruntime/core/providers/cuda/tensor/expand.cc b/onnxruntime/core/providers/cuda/tensor/expand.cc index e9634df205842..368c167f58641 100644 --- a/onnxruntime/core/providers/cuda/tensor/expand.cc +++ b/onnxruntime/core/providers/cuda/tensor/expand.cc @@ -142,6 +142,86 @@ Status Expand::ComputeInternal(OpKernelContext* ctx) const { input_strides); } +Status FuncExpand( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* input_data_tensor, + const Tensor* /*input_shape_tensor*/, + Tensor* output_tensor) { + + TensorShape output_shape = output_tensor->Shape(); + +#ifdef ENABLE_STRIDED_TENSORS + // Strided output. + if (input_data_tensor->DataRaw() == output_tensor->DataRaw()) { + gsl::span input_strides = input_data_tensor->Strides(); + TensorShapeVector output_strides = + ComputeOutputStrides(input_data_tensor->Shape(), input_strides, output_shape); + output_tensor->SetShapeAndStrides(output_shape, output_strides); + return Status::OK(); + } +#endif + + auto output_dims = output_shape.AsShapeVector(); + auto input_dims = input_data_tensor->Shape().AsShapeVector(); + + CalcEffectiveDims(input_dims, output_dims); + int rank = gsl::narrow_cast(output_dims.size()); + + TensorPitches original_input_strides(input_dims); + TensorPitches original_output_strides(output_dims); + + TArray input_strides(rank); + for (auto i = 0; i < rank; i++) { + input_strides[i] = input_dims[i] == 1 ? 0 : original_input_strides[i]; + } + + TArray output_strides(rank); + for (auto i = 0; i < rank; i++) { + output_strides[i] = fast_divmod(static_cast(original_output_strides[i])); + } + + return ExpandImpl( + cuda_kernel->Stream(ctx), + input_data_tensor->DataType()->Size(), + gsl::narrow_cast(output_shape.Size()), + gsl::narrow_cast(input_data_tensor->Shape().Size()), + input_data_tensor->DataRaw(), + output_tensor->MutableDataRaw(), + output_strides, + input_strides); +} + +std::unique_ptr FuncExpand( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* input_data_tensor, + const Tensor* input_shape_tensor) { + // new shape to be expanded to + const auto* p_shape = input_shape_tensor->Data(); + TensorShapeVector output_dims{p_shape, p_shape + input_shape_tensor->Shape().Size()}; + TensorShape output_shape(output_dims); + + ORT_ENFORCE( + ComputeOutputShape( + cuda_kernel->Node().Name(), + input_data_tensor->Shape(), + output_dims, output_shape).IsOK()); + + // Pre-allocate output. + AllocatorPtr alloc; + ORT_ENFORCE(ctx->GetTempSpaceAllocator(&alloc).IsOK()); + auto output_tensor = Tensor::Create(input_data_tensor->DataType(), output_shape, alloc); + + // Only assign output values when output tensor is non-empty + // because empty tensor doesn't own any data. 
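+  // (For example, expanding against a target shape that contains a zero dimension produces an
+  // empty output, so the expansion below is skipped.)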
+ if (output_shape.Size() > 0) { + ORT_ENFORCE(FuncExpand(cuda_kernel, ctx, input_data_tensor, input_shape_tensor, output_tensor.get()).IsOK()); + } + + return output_tensor; +} + #ifdef ENABLE_STRIDED_TENSORS #define CREATE_EXPAND_KERNEL_DEF (*KernelDefBuilder::Create()).MayStridedOutput(0, 0) #else diff --git a/onnxruntime/core/providers/cuda/tensor/expand.h b/onnxruntime/core/providers/cuda/tensor/expand.h index 4cf4c14e61058..a0b12790017f6 100644 --- a/onnxruntime/core/providers/cuda/tensor/expand.h +++ b/onnxruntime/core/providers/cuda/tensor/expand.h @@ -20,5 +20,18 @@ Status ComputeOutputShape( const TensorShape& rhs_shape, TensorShape& out_shape); +Status FuncExpand( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* input_data_tensor, + const Tensor* /*input_shape_tensor*/, + Tensor* output_tensor); + +std::unique_ptr FuncExpand( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* input_data_tensor, + const Tensor* input_shape_tensor); + } // namespace cuda } // namespace onnxruntime From ea33392f375afd8e95d29bd5b1a403192ed3bebc Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Thu, 26 Oct 2023 18:36:47 -0700 Subject: [PATCH 2/5] Add tests --- .../python/onnxruntime_test_distributed.py | 185 ++++++++++++++++++ 1 file changed, 185 insertions(+) diff --git a/onnxruntime/test/python/onnxruntime_test_distributed.py b/onnxruntime/test/python/onnxruntime_test_distributed.py index 2acca4a8f22ae..e0fb3979a9f55 100644 --- a/onnxruntime/test/python/onnxruntime_test_distributed.py +++ b/onnxruntime/test/python/onnxruntime_test_distributed.py @@ -685,6 +685,191 @@ def test_reshape_two_axis_fusion_shape_3_7_4096_rrs_01_shape_21_4906_rs_01(self) ) +class TestDistributedExpand(unittest.TestCase): + def _check_distributed_expand( + self, + shape: Tuple[int, ...], + target_shape: Tuple[int, ...], + input_device_meshs: np.ndarray, + input_shard_specs: Tuple[str, ...], + output_device_meshs: np.ndarray, + output_shard_specs: Tuple[str, ...], + ): + assert all(len(mesh.shape) == 1 for mesh in input_device_meshs) + assert all(len(mesh.shape) == 1 for mesh in output_device_meshs) + assert len(input_device_meshs) == len(input_shard_specs) + assert len(output_device_meshs) == len(output_shard_specs) + + input_device_mesh_shapes = [] + input_device_mesh_elements = [] + for device_mesh in input_device_meshs: + device_mesh_shape, device_mesh_element = translate_device_mesh_to_attrs(device_mesh) + input_device_mesh_shapes.append(device_mesh_shape) + input_device_mesh_elements.append(device_mesh_element) + + output_device_mesh_shapes = [] + output_device_mesh_elements = [] + for device_mesh in output_device_meshs: + device_mesh_shape, device_mesh_element = translate_device_mesh_to_attrs(device_mesh) + output_device_mesh_shapes.append(device_mesh_shape) + output_device_mesh_elements.append(device_mesh_element) + + @onnxscript.script() + def distributed_expand_instance(data_tensor: FLOAT, shape_tensor: INT64): + return MICROSOFT_OPSET.DistributedExpand( + data_tensor, + shape_tensor, + input_device_mesh_shapes=input_device_mesh_shapes, + input_device_mesh_elements=input_device_mesh_elements, + input_shard_specs=input_shard_specs, + output_device_mesh_shapes=output_device_mesh_shapes, + output_device_mesh_elements=output_device_mesh_elements, + output_shard_specs=output_shard_specs, + ) + + rank = comm.Get_rank() + data_tensor = np.arange(np.prod(shape), dtype=np.float32).reshape(*shape) + shape_tensor = np.array( + target_shape, + dtype=np.int64, + ) + + 
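+        # Shard the logical data tensor according to its input sharding spec so that each MPI
+        # rank feeds only its local shard; the target-shape input stays replicated on all ranks.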
local_data_tensor = shard_tensor_per_spec(data_tensor, rank, input_shard_specs[0], input_device_meshs[0]) + assert "S" not in input_shard_specs[1], "Shape should not be sharded." + + expected = data_tensor * np.ones(shape_tensor) + local_expected = shard_tensor_per_spec(expected, rank, output_shard_specs[0], output_device_meshs[0]) + + onnx_model = distributed_expand_instance.to_model_proto( + input_types=[FLOAT[tuple(local_data_tensor.shape)], INT64[tuple(shape_tensor.shape)]], + output_types=[FLOAT[tuple(local_expected.shape)]], + ) + + # Each MPI process owns a sharded model. + sess = ort.InferenceSession( + onnx_model.SerializeToString(), + providers=["CUDAExecutionProvider"], + provider_options=[{"device_id": str(rank)}], + ) + + # Each MPI process executes its sharded model. + # The result is `local` tensor stored on a specific MPI rank + # instead of `logical` tensor. + result = sess.run( + None, + { + "data_tensor": local_data_tensor, + "shape_tensor": shape_tensor, + }, + ) + + # Compare local tensor and the corresponding logical sub-tensor + # obtained by sharding logical tensor following output's sharding spec. + np.testing.assert_allclose(result[0], local_expected, rtol=1e-5, atol=1e-8) + + def test_expand_sharded_on_expanded_axis(self): + # data: shape=[8,1], spec=(RR, [0,1]) + # shape: shape=[2], spec=(R, [0,1]), value=[1,4] + # output: shape=[8,4], spec=(RS, [0,1]) + self._check_distributed_expand( + shape=( + 8, + 1, + ), + target_shape=( + 8, + 4, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RR", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RS[0]",), + ) + + def test_expand_sharded_on_expanded_axis_with_device_mesh_0101(self): + # data: shape=[8,1], spec=(RR, [0,1]) + # shape: shape=[2], spec=(R, [0,1]), value=[1,4] + # output: shape=[8,4], spec=(RS, [0,1]) + self._check_distributed_expand( + shape=( + 8, + 1, + ), + target_shape=( + 8, + 8, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RR", "R"), + output_device_meshs=[np.array([0, 1, 0, 1])], + output_shard_specs=("RS[0]",), + ) + + def test_expand_replicated_on_expanded_axis(self): + # data: shape=[8,1], spec=(RR, [0,1]) + # shape: shape=[2], spec=(R, [0,1]), value=[1,4] + # output: shape=[8,4], spec=(RR, [0,1]) + self._check_distributed_expand( + shape=( + 8, + 1, + ), + target_shape=( + 1, + 4, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RR", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RR",), + ) + + def test_expand_with_pass_through_sharding_spec(self): + # data: shape=[8,1], spec=(SR, [0,1]) + # shape: shape=[2], spec=(R, [0,1]), value=[1,4] + # output: shape=[8,4], spec=(SR, [0,1]) + self._check_distributed_expand( + shape=( + 8, + 1, + ), + target_shape=( + 1, + 4, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=( + "S[0]R", + "R", + ), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("S[0]R",), + ) + + def test_expand_in_tiny_llama(self): + # data: shape=[2,4,256,4], spec=(RSRR, [0,1]) + # shape: shape=[4], spec=(R, [0,1,2,3]), value=[2,4,256,4] + # output: shape=[2,4,256,4], spec=(RSRR, [0,1]) + self._check_distributed_expand( + shape=( + 2, + 4, + 256, + 4, + ), + target_shape=( + 2, + 4, + 256, + 4, + ), + input_device_meshs=[np.array([0, 1])] * 2, + input_shard_specs=("RS[0]RR", "R"), + output_device_meshs=[np.array([0, 1])], + output_shard_specs=("RS[0]RR",), + ) + + class TestDistributed(unittest.TestCase): def test_matmul_rs_sr_rr(self): 
# It means 1-D tensor with single element: [2]. From 68ac301bbaff44d08168ac9049161a4d428b3c3d Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Thu, 26 Oct 2023 11:42:59 -0700 Subject: [PATCH 3/5] Skeleton of DistributedExpand --- cmake/onnxruntime_providers_cuda.cmake | 1 + cmake/onnxruntime_rocm_hipify.cmake | 1 + .../cuda/collective/distributed_expand.cc | 69 +++++++++++++++++++ .../cuda/collective/distributed_expand.h | 35 ++++++++++ .../contrib_ops/cuda/cuda_contrib_kernels.cc | 8 +++ .../core/graph/contrib_ops/collective_defs.cc | 37 ++++++++++ 6 files changed, 151 insertions(+) create mode 100644 onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc create mode 100644 onnxruntime/contrib_ops/cuda/collective/distributed_expand.h diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 02b17ee324f4f..043789c36c327 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -39,6 +39,7 @@ "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_matmul.cc" "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_slice.cc" "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_reshape.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_expand.cc" ) endif() # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index ec021a1550d6c..6ccf063c71290 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -108,6 +108,7 @@ if (NOT onnxruntime_USE_NCCL) list(APPEND contrib_ops_excluded_files "collective/distributed_matmul.cc") list(APPEND contrib_ops_excluded_files "collective/distributed_slice.cc") list(APPEND contrib_ops_excluded_files "collective/distributed_reshape.cc") + list(APPEND contrib_ops_excluded_files "collective/distributed_expand.cc") endif() set(provider_excluded_files diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc b/onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc new file mode 100644 index 0000000000000..a946e8812d3ff --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc @@ -0,0 +1,69 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Distributed computation. +#include "distributed_expand.h" +#include "sharding.h" +#include "sharding_spec.h" +#include "nccl_kernels.h" +#include "mpi_include.h" + +// ORT system. +#include "core/providers/cuda/tensor/expand.h" + +// std C++. 
+#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +DistributedExpand::DistributedExpand(const OpKernelInfo& info) : DistributedKernel(info) {} + +template +Status DistributedExpand::ComputeInternal(OpKernelContext* context) const { + ORT_ENFORCE(context != nullptr); + return Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, "Encounter unsupported expand pattern."); +} + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedExpand, + kMSDomain, + 1, + int64_t, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedExpand); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedExpand, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedExpand); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + DistributedExpand, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPUInput, 1), + DistributedExpand); + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_expand.h b/onnxruntime/contrib_ops/cuda/collective/distributed_expand.h new file mode 100644 index 0000000000000..dedb1bdc5aa36 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_expand.h @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "sharding_spec.h" +#include "sharding.h" +#include "core/providers/cuda/cuda_kernel.h" + +#include +#include +#include +#include +#include +#include + +#pragma once + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#if defined(ORT_USE_NCCL) + +template +class DistributedExpand final : public DistributedKernel { + public: + explicit DistributedExpand(const OpKernelInfo& info); + + Status ComputeInternal(OpKernelContext* context) const override; +}; + +#endif + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index e6a216795c10b..2618fe4a238bd 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -170,6 +170,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedReshape); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReshape); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReshape); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedExpand); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedExpand); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedExpand); #endif template <> @@ -344,6 +348,10 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + 
BuildKernelCreateInfo, #endif }; diff --git a/onnxruntime/core/graph/contrib_ops/collective_defs.cc b/onnxruntime/core/graph/contrib_ops/collective_defs.cc index 8082b8c010e91..070df487a264d 100644 --- a/onnxruntime/core/graph/contrib_ops/collective_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/collective_defs.cc @@ -236,6 +236,43 @@ void RegisterCollectiveOps() { OpSchema::NonDifferentiable) .Output(0, "reshaped", "Reshaped data.", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensor types."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedExpand) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Attr("input_device_mesh_elements", + "device_mesh_elements[i] defines the device mesh's value for the i-th input. " + "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd " + " inputs are stored on the 0-th and the 1st devices, respectively.", + AttributeProto::STRINGS) + .Attr("input_device_mesh_shapes", + "device_mesh_shape[i] defines the device mesh's shape for the i-th input.", + AttributeProto::STRINGS) + .Attr("input_shard_specs", + "The sharding spec of inputs. " + "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_elements", + "Similar to input_device_mesh_elments but for outputs.", + AttributeProto::STRINGS) + .Attr("output_device_mesh_shapes", + "Similar to input_device_mesh_shapes but for outputs.", + AttributeProto::STRINGS) + .Attr("output_shard_specs", + "Similar to input_shard_specs but for outputs.", + AttributeProto::STRINGS) + .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Input( + 1, + "shape", + "A 1-D tensor indicates the shape you want to expand to, following the broadcast rule", + "tensor(int64)", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors."); } } // namespace contrib From 0eb9330c3ba836911932444caca7fec0cbdad222 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Thu, 26 Oct 2023 20:13:48 -0700 Subject: [PATCH 4/5] Implement details for d-expand Fix a function call --- .../cuda/collective/distributed_expand.cc | 42 ++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc b/onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc index a946e8812d3ff..ec1826d1eabd2 100644 --- a/onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc @@ -26,7 +26,47 @@ DistributedExpand::DistributedExpand(const OpKernelInfo& info) : DistributedK template Status DistributedExpand::ComputeInternal(OpKernelContext* context) const { ORT_ENFORCE(context != nullptr); - return Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, "Encounter unsupported expand pattern."); + // Assumptions. + // - Shape is not sharded. + // Algorithm. + // - Compute logical output shape. + // - Compute local output shape. + // - Expand from local input to local output. 
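+  // For example (mirroring the Python tests): with device mesh [0, 1], a replicated [8, 1]
+  // input (spec RR), logical target shape [8, 4], and output spec RS[0], each rank expands its
+  // full local input into its own [8, 2] shard of the logical [8, 4] output.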
+ + auto input_tensor = context->Input(0); + auto shape_tensor = context->Input(1); + const auto& input_sharding_spec = input_shard_specs_.at(0); + const auto& shape_sharding_spec = input_shard_specs_.at(1); + const auto& output_sharding_spec = output_shard_specs_.at(0); + + ORT_ENFORCE(shape_sharding_spec.HasNoShard(), + "It's not worth to shard Shape tensor. " + "If sharding shape is needed, please submit a feature request."); + // Compute logical input shape. + const auto original_input_shape = ComputeOriginShape(input_tensor->Shape(), input_sharding_spec); + + // Compute logical output shape. + // This `shape_tensor` stores the logical output shape. + const auto* p_shape = shape_tensor->Data(); + TensorShapeVector original_output_dims{p_shape, p_shape + shape_tensor->Shape().Size()}; + TensorShape original_output_shape(original_output_dims); + ORT_ENFORCE( + onnxruntime::cuda::ComputeOutputShape( + Node().Name(), + original_input_shape, + original_output_dims, original_output_shape).IsOK()); + + // Compute local output shape. + const auto local_output_shape = ComputeShardShape(original_output_shape, output_sharding_spec); + + auto output_tensor = context->Output(0, local_output_shape); + + return FuncExpand( + this, + context, + input_tensor, + shape_tensor, + output_tensor); } ONNX_OPERATOR_TYPED_KERNEL_EX( From 128f87c1a3149533b8bad0ef87e37c9055a6a7c5 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Fri, 27 Oct 2023 11:10:28 -0700 Subject: [PATCH 5/5] lint --- .../cuda/collective/distributed_expand.cc | 19 ++++++++++--------- .../core/providers/cuda/tensor/expand.cc | 10 +++++----- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc b/onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc index ec1826d1eabd2..3cfa3ab959343 100644 --- a/onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc @@ -51,10 +51,11 @@ Status DistributedExpand::ComputeInternal(OpKernelContext* context) const { TensorShapeVector original_output_dims{p_shape, p_shape + shape_tensor->Shape().Size()}; TensorShape original_output_shape(original_output_dims); ORT_ENFORCE( - onnxruntime::cuda::ComputeOutputShape( - Node().Name(), - original_input_shape, - original_output_dims, original_output_shape).IsOK()); + onnxruntime::cuda::ComputeOutputShape( + Node().Name(), + original_input_shape, + original_output_dims, original_output_shape) + .IsOK()); // Compute local output shape. 
const auto local_output_shape = ComputeShardShape(original_output_shape, output_sharding_spec); @@ -62,11 +63,11 @@ Status DistributedExpand::ComputeInternal(OpKernelContext* context) const { auto output_tensor = context->Output(0, local_output_shape); return FuncExpand( - this, - context, - input_tensor, - shape_tensor, - output_tensor); + this, + context, + input_tensor, + shape_tensor, + output_tensor); } ONNX_OPERATOR_TYPED_KERNEL_EX( diff --git a/onnxruntime/core/providers/cuda/tensor/expand.cc b/onnxruntime/core/providers/cuda/tensor/expand.cc index 368c167f58641..806ecfa1aab17 100644 --- a/onnxruntime/core/providers/cuda/tensor/expand.cc +++ b/onnxruntime/core/providers/cuda/tensor/expand.cc @@ -148,7 +148,6 @@ Status FuncExpand( const Tensor* input_data_tensor, const Tensor* /*input_shape_tensor*/, Tensor* output_tensor) { - TensorShape output_shape = output_tensor->Shape(); #ifdef ENABLE_STRIDED_TENSORS @@ -203,10 +202,11 @@ std::unique_ptr FuncExpand( TensorShape output_shape(output_dims); ORT_ENFORCE( - ComputeOutputShape( - cuda_kernel->Node().Name(), - input_data_tensor->Shape(), - output_dims, output_shape).IsOK()); + ComputeOutputShape( + cuda_kernel->Node().Name(), + input_data_tensor->Shape(), + output_dims, output_shape) + .IsOK()); // Pre-allocate output. AllocatorPtr alloc;
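For reference, a minimal per-rank usage sketch of the new DistributedExpand op, mirroring test_expand_sharded_on_expanded_axis above. It is a sketch, not part of the patch: it assumes a CUDA build with NCCL, two MPI ranks (one GPU each), that MICROSOFT_OPSET resolves the com.microsoft domain the same way onnxruntime_test_distributed.py does, and that the device-mesh attribute strings shown here match what translate_device_mesh_to_attrs would produce.

# Sketch only -- assumes two MPI ranks, each bound to its own GPU.
import numpy as np
import onnxruntime as ort
import onnxscript
from mpi4py import MPI
from onnxscript import FLOAT, INT64

# com.microsoft opset handle, as assumed to be defined in onnxruntime_test_distributed.py.
MICROSOFT_OPSET = onnxscript.values.Opset(domain="com.microsoft", version=1)

@onnxscript.script()
def sharded_expand(data_tensor: FLOAT, shape_tensor: INT64):
    # data: [8, 1] replicated (RR); shape: replicated (R); output: [8, 4] column-sharded (RS[0]).
    return MICROSOFT_OPSET.DistributedExpand(
        data_tensor,
        shape_tensor,
        input_device_mesh_shapes=["[2]", "[2]"],
        input_device_mesh_elements=["[0, 1]", "[0, 1]"],
        input_shard_specs=["RR", "R"],
        output_device_mesh_shapes=["[2]"],
        output_device_mesh_elements=["[0, 1]"],
        output_shard_specs=["RS[0]"],
    )

rank = MPI.COMM_WORLD.Get_rank()
data = np.arange(8, dtype=np.float32).reshape(8, 1)   # full logical input on every rank
shape = np.array([8, 4], dtype=np.int64)              # logical target shape (a CPU input to the kernel)

onnx_model = sharded_expand.to_model_proto(
    input_types=[FLOAT[8, 1], INT64[2]],
    output_types=[FLOAT[8, 2]],                        # this rank's shard of the [8, 4] output
)
sess = ort.InferenceSession(
    onnx_model.SerializeToString(),
    providers=["CUDAExecutionProvider"],
    provider_options=[{"device_id": str(rank)}],
)
local_out = sess.run(None, {"data_tensor": data, "shape_tensor": shape})[0]  # shape (8, 2) on each rank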