Distributed Expand (microsoft#18126)
This PR implements DistributedExpand for Llama 2.

Representative examples of DistributedExpand:
- [shard on a non-expanded axis] `input tensor (shape=[8, 1], spec=S[0]R, device_mesh=[0,1]) -> Expand(target_shape=[8, 2]) -> output tensor (shape=[8, 2], spec=S[0]R, device_mesh=[0,1])`
- [sharding the expanded axis is invalid, since the expanded axis must have dim=1 and an axis with dim=1 cannot be sharded] `input tensor (shape=[1, 8], spec=S[0]R, device_mesh=[0,1]) -> Expand(target_shape=[2, 8]) -> output tensor (shape=[2, 8], spec=S[0]R, device_mesh=[0,1])`

From these examples, we observe a few important behaviors.

- The output sharding spec is always the same as the input sharding spec.
- Expanding always happens on axes with dimension 1; anything else would violate the broadcasting rule.
- No communication is needed, since all computation can happen locally. Consider the first example again: if you put the first half of the tensor (shape [4, 1]) on device 0 and the second half (shape [4, 1]) on device 1 and then `Expand` each of them to the local target shape [4, 2], the two local tensors (shape [4, 2]) are exactly the shards described by the output sharding spec.

Algorithm (a code sketch follows this list):
- Compute the logical (i.e., unsharded) shapes of the input and output.
- Compute the sharded output shape from the logical output shape.
- Call Expand to broadcast the local input to the sharded output shape.
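
To make this concrete, here is a minimal Python sketch of the algorithm, worked on the first representative example. The helper names are hypothetical stand-ins for ORT's `ComputeOriginShape`/`ComputeShardShape`; the sketch assumes a 1-D device mesh and at most one sharded axis per tensor.

```python
# Minimal sketch only; not the ORT implementation. Assumes a 1-D device mesh
# and at most one sharded axis per tensor.
def compute_origin_shape(local_shape, shard_axis, num_devices):
    """Recover the logical (unsharded) shape from one device's local shape."""
    shape = list(local_shape)
    if shard_axis is not None:
        shape[shard_axis] *= num_devices
    return shape

def compute_shard_shape(logical_shape, shard_axis, num_devices):
    """Per-device shard shape for a given logical shape and sharded axis."""
    shape = list(logical_shape)
    if shard_axis is not None:
        assert shape[shard_axis] % num_devices == 0, "sharded axis must divide evenly"
        shape[shard_axis] //= num_devices
    return shape

# First representative example: logical input [8, 1], spec S[0]R, device mesh [0, 1].
local_input = [4, 1]                                       # each device holds half of axis 0
logical_input = compute_origin_shape(local_input, 0, 2)    # [8, 1]
logical_output = [8, 2]                                     # taken from Expand's `shape` input
local_output = compute_shard_shape(logical_output, 0, 2)   # [4, 2]
# Each device then runs a plain local Expand from its [4, 1] shard to [4, 2];
# no communication is needed.
```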

How to review?
- Start with [changes in onnxruntime_test_distributed.py](microsoft@ea33392). Those tests are good examples of how to use this op (a rough sketch of such a node appears after this list).
- [Read expand.h/expand.cc](microsoft@e4c4998). Those changes expose Expand's functionality to DistributedExpand.
- Read distributed_expand.h/distributed_expand.cc. It follows the algorithm described above. The commit microsoft@68ac301 first sketches the definition of DistributedExpand, and the next commit microsoft@0eb9330 adds the real implementation.
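
For orientation, the following is a rough sketch of how a DistributedExpand node for the first representative example could be constructed with `onnx.helper`. It is not copied from the tests; the attribute names come from the schema added in collective_defs.cc, while the exact attribute string formats (e.g. "[2]" for a mesh shape) are assumptions.

```python
from onnx import helper

# Rough sketch (not taken from onnxruntime_test_distributed.py): a DistributedExpand
# node for input [8, 1] with spec S[0]R on the 2-device mesh [0, 1]. Attribute names
# follow the schema in collective_defs.cc; the value string formats are assumptions.
node = helper.make_node(
    "DistributedExpand",
    inputs=["input", "shape"],
    outputs=["output"],
    domain="com.microsoft",
    input_device_mesh_elements=["[0, 1]", "[0, 1]"],
    input_device_mesh_shapes=["[2]", "[2]"],
    input_shard_specs=["S[0]R", "R"],   # data sharded on axis 0; 1-D shape tensor replicated
    output_device_mesh_elements=["[0, 1]"],
    output_device_mesh_shapes=["[2]"],
    output_shard_specs=["S[0]R"],       # output keeps the input's sharding spec
)
```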
wschin authored Oct 28, 2023
1 parent 8daabf3 commit 24f9c1a
Showing 9 changed files with 470 additions and 0 deletions.
1 change: 1 addition & 0 deletions cmake/onnxruntime_providers_cuda.cmake
@@ -39,6 +39,7 @@
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_matmul.cc"
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_slice.cc"
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_reshape.cc"
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_expand.cc"
)
endif()
# add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio
1 change: 1 addition & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -108,6 +108,7 @@ if (NOT onnxruntime_USE_NCCL)
list(APPEND contrib_ops_excluded_files "collective/distributed_matmul.cc")
list(APPEND contrib_ops_excluded_files "collective/distributed_slice.cc")
list(APPEND contrib_ops_excluded_files "collective/distributed_reshape.cc")
list(APPEND contrib_ops_excluded_files "collective/distributed_expand.cc")
endif()

set(provider_excluded_files
110 changes: 110 additions & 0 deletions onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc
@@ -0,0 +1,110 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// Distributed computation.
#include "distributed_expand.h"
#include "sharding.h"
#include "sharding_spec.h"
#include "nccl_kernels.h"
#include "mpi_include.h"

// ORT system.
#include "core/providers/cuda/tensor/expand.h"

// std C++.
#include <iostream>

namespace onnxruntime {
namespace contrib {
namespace cuda {

#if defined(ORT_USE_NCCL)

template <typename T>
DistributedExpand<T>::DistributedExpand(const OpKernelInfo& info) : DistributedKernel(info) {}

template <typename T>
Status DistributedExpand<T>::ComputeInternal(OpKernelContext* context) const {
ORT_ENFORCE(context != nullptr);
// Assumptions.
// - Shape is not sharded.
// Algorithm.
// - Compute logical output shape.
// - Compute local output shape.
// - Expand from local input to local output.

auto input_tensor = context->Input<Tensor>(0);
auto shape_tensor = context->Input<Tensor>(1);
const auto& input_sharding_spec = input_shard_specs_.at(0);
const auto& shape_sharding_spec = input_shard_specs_.at(1);
const auto& output_sharding_spec = output_shard_specs_.at(0);

ORT_ENFORCE(shape_sharding_spec.HasNoShard(),
"It's not worth to shard Shape tensor. "
"If sharding shape is needed, please submit a feature request.");
// Compute logical input shape.
const auto original_input_shape = ComputeOriginShape(input_tensor->Shape(), input_sharding_spec);

// Compute logical output shape.
// This `shape_tensor` stores the logical output shape.
const auto* p_shape = shape_tensor->Data<int64_t>();
TensorShapeVector original_output_dims{p_shape, p_shape + shape_tensor->Shape().Size()};
TensorShape original_output_shape(original_output_dims);
ORT_ENFORCE(
onnxruntime::cuda::ComputeOutputShape(
Node().Name(),
original_input_shape,
original_output_dims, original_output_shape)
.IsOK());

// Compute local output shape.
const auto local_output_shape = ComputeShardShape(original_output_shape, output_sharding_spec);

auto output_tensor = context->Output(0, local_output_shape);

return FuncExpand(
this,
context,
input_tensor,
shape_tensor,
output_tensor);
}

ONNX_OPERATOR_TYPED_KERNEL_EX(
DistributedExpand,
kMSDomain,
1,
int64_t,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::GetTensorType<int64_t>())
.InputMemoryType(OrtMemTypeCPUInput, 1),
DistributedExpand<int64_t>);

ONNX_OPERATOR_TYPED_KERNEL_EX(
DistributedExpand,
kMSDomain,
1,
float,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
.InputMemoryType(OrtMemTypeCPUInput, 1),
DistributedExpand<float>);

ONNX_OPERATOR_TYPED_KERNEL_EX(
DistributedExpand,
kMSDomain,
1,
MLFloat16,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>())
.InputMemoryType(OrtMemTypeCPUInput, 1),
DistributedExpand<MLFloat16>);

#endif

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
35 changes: 35 additions & 0 deletions onnxruntime/contrib_ops/cuda/collective/distributed_expand.h
@@ -0,0 +1,35 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "sharding_spec.h"
#include "sharding.h"
#include "core/providers/cuda/cuda_kernel.h"

#include <algorithm>
#include <tuple>
#include <optional>
#include <string>
#include <nccl.h>
#include <sstream>

#pragma once

namespace onnxruntime {
namespace contrib {
namespace cuda {

#if defined(ORT_USE_NCCL)

template <typename T>
class DistributedExpand final : public DistributedKernel {
public:
explicit DistributedExpand(const OpKernelInfo& info);

Status ComputeInternal(OpKernelContext* context) const override;
};

#endif

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
8 changes: 8 additions & 0 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -170,6 +170,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedReshape);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReshape);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReshape);

class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedExpand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedExpand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedExpand);
#endif

template <>
@@ -344,6 +348,10 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedReshape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReshape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReshape)>,

BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedExpand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedExpand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedExpand)>,
#endif

};
37 changes: 37 additions & 0 deletions onnxruntime/core/graph/contrib_ops/collective_defs.cc
@@ -236,6 +236,43 @@ void RegisterCollectiveOps() {
OpSchema::NonDifferentiable)
.Output(0, "reshaped", "Reshaped data.", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
.TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensor types.");

ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedExpand)
.SetDomain(kMSDomain)
.SinceVersion(1)
.Attr("input_device_mesh_elements",
"device_mesh_elements[i] defines the device mesh's value for the i-th input. "
"E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd "
" inputs are stored on the 0-th and the 1st devices, respectively.",
AttributeProto::STRINGS)
.Attr("input_device_mesh_shapes",
"device_mesh_shape[i] defines the device mesh's shape for the i-th input.",
AttributeProto::STRINGS)
.Attr("input_shard_specs",
"The sharding spec of inputs. "
"E.g., if input_shard_specs[i] is \"RRR\", the i-th input is a unsharded 3-D tensor.",
AttributeProto::STRINGS)
.Attr("output_device_mesh_elements",
"Similar to input_device_mesh_elments but for outputs.",
AttributeProto::STRINGS)
.Attr("output_device_mesh_shapes",
"Similar to input_device_mesh_shapes but for outputs.",
AttributeProto::STRINGS)
.Attr("output_shard_specs",
"Similar to input_shard_specs but for outputs.",
AttributeProto::STRINGS)
.Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
.Input(
1,
"shape",
"A 1-D tensor indicates the shape you want to expand to, following the broadcast rule",
"tensor(int64)",
OpSchema::Single,
true,
1,
OpSchema::NonDifferentiable)
.Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
.TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors.");
}

} // namespace contrib
80 changes: 80 additions & 0 deletions onnxruntime/core/providers/cuda/tensor/expand.cc
@@ -142,6 +142,86 @@ Status Expand::ComputeInternal(OpKernelContext* ctx) const {
input_strides);
}

Status FuncExpand(
const CudaKernel* cuda_kernel,
OpKernelContext* ctx,
const Tensor* input_data_tensor,
const Tensor* /*input_shape_tensor*/,
Tensor* output_tensor) {
TensorShape output_shape = output_tensor->Shape();

#ifdef ENABLE_STRIDED_TENSORS
// Strided output.
if (input_data_tensor->DataRaw() == output_tensor->DataRaw()) {
gsl::span<const int64_t> input_strides = input_data_tensor->Strides();
TensorShapeVector output_strides =
ComputeOutputStrides(input_data_tensor->Shape(), input_strides, output_shape);
output_tensor->SetShapeAndStrides(output_shape, output_strides);
return Status::OK();
}
#endif

auto output_dims = output_shape.AsShapeVector();
auto input_dims = input_data_tensor->Shape().AsShapeVector();

CalcEffectiveDims(input_dims, output_dims);
int rank = gsl::narrow_cast<int>(output_dims.size());

TensorPitches original_input_strides(input_dims);
TensorPitches original_output_strides(output_dims);

TArray<int64_t> input_strides(rank);
for (auto i = 0; i < rank; i++) {
input_strides[i] = input_dims[i] == 1 ? 0 : original_input_strides[i];
}

TArray<fast_divmod> output_strides(rank);
for (auto i = 0; i < rank; i++) {
output_strides[i] = fast_divmod(static_cast<int>(original_output_strides[i]));
}

return ExpandImpl(
cuda_kernel->Stream(ctx),
input_data_tensor->DataType()->Size(),
gsl::narrow_cast<int>(output_shape.Size()),
gsl::narrow_cast<int>(input_data_tensor->Shape().Size()),
input_data_tensor->DataRaw(),
output_tensor->MutableDataRaw(),
output_strides,
input_strides);
}

std::unique_ptr<Tensor> FuncExpand(
const CudaKernel* cuda_kernel,
OpKernelContext* ctx,
const Tensor* input_data_tensor,
const Tensor* input_shape_tensor) {
// new shape to be expanded to
const auto* p_shape = input_shape_tensor->Data<int64_t>();
TensorShapeVector output_dims{p_shape, p_shape + input_shape_tensor->Shape().Size()};
TensorShape output_shape(output_dims);

ORT_ENFORCE(
ComputeOutputShape(
cuda_kernel->Node().Name(),
input_data_tensor->Shape(),
output_dims, output_shape)
.IsOK());

// Pre-allocate output.
AllocatorPtr alloc;
ORT_ENFORCE(ctx->GetTempSpaceAllocator(&alloc).IsOK());
auto output_tensor = Tensor::Create(input_data_tensor->DataType(), output_shape, alloc);

// Only assign output values when output tensor is non-empty
// because empty tensor doesn't own any data.
if (output_shape.Size() > 0) {
ORT_ENFORCE(FuncExpand(cuda_kernel, ctx, input_data_tensor, input_shape_tensor, output_tensor.get()).IsOK());
}

return output_tensor;
}

#ifdef ENABLE_STRIDED_TENSORS
#define CREATE_EXPAND_KERNEL_DEF (*KernelDefBuilder::Create()).MayStridedOutput(0, 0)
#else
13 changes: 13 additions & 0 deletions onnxruntime/core/providers/cuda/tensor/expand.h
@@ -20,5 +20,18 @@ Status ComputeOutputShape(
const TensorShape& rhs_shape,
TensorShape& out_shape);

Status FuncExpand(
const CudaKernel* cuda_kernel,
OpKernelContext* ctx,
const Tensor* input_data_tensor,
const Tensor* /*input_shape_tensor*/,
Tensor* output_tensor);

std::unique_ptr<Tensor> FuncExpand(
const CudaKernel* cuda_kernel,
OpKernelContext* ctx,
const Tensor* input_data_tensor,
const Tensor* input_shape_tensor);

} // namespace cuda
} // namespace onnxruntime