distributed slice (#17761)
### Description
Support the DistributedSlice kernel in the CUDA EP.

It mainly supports the following cases:
1. The input data is sharded or replicated along any axis (including the slice axes).
2. The slice axes are sharded across different devices.

Sharding `starts` / `ends` / `steps` across different devices is not supported yet.
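
The description above corresponds to two code paths in `DistributedSlice<T, Tind>::ComputeInternal` (see the diff below): when the sliced axis is also the sharded axis, the shards are first gathered back into a replicated tensor (`ReshardTensor`) and then sliced (`FuncSlice`); otherwise each rank slices its local shard directly. The following is a minimal standalone sketch of that idea with no ORT or NCCL dependencies; `gather_rows`, `slice_rows`, and `slice_cols` are illustrative stand-ins, not ONNX Runtime APIs.

```cpp
// Minimal sketch (not ONNX Runtime code): the two DistributedSlice paths on a 2-D tensor.
#include <cassert>
#include <cstdint>
#include <vector>

using Matrix = std::vector<std::vector<int64_t>>;  // row-major 2-D tensor

// Simulated all-gather along axis 0: concatenate the per-rank row shards.
Matrix gather_rows(const std::vector<Matrix>& shards) {
  Matrix full;
  for (const auto& shard : shards) full.insert(full.end(), shard.begin(), shard.end());
  return full;
}

// Slice rows [start, end) with the given step (axis 0), mirroring ONNX Slice semantics.
Matrix slice_rows(const Matrix& x, int64_t start, int64_t end, int64_t step) {
  Matrix y;
  for (int64_t i = start; i < end; i += step) y.push_back(x[static_cast<size_t>(i)]);
  return y;
}

// Slice columns [start, end) with the given step (axis 1).
Matrix slice_cols(const Matrix& x, int64_t start, int64_t end, int64_t step) {
  Matrix y;
  for (const auto& row : x) {
    std::vector<int64_t> r;
    for (int64_t j = start; j < end; j += step) r.push_back(row[static_cast<size_t>(j)]);
    y.push_back(r);
  }
  return y;
}

int main() {
  // A 4x2 tensor sharded on axis 0 across two ranks (two rows per rank).
  std::vector<Matrix> shards = {{{0, 1}, {2, 3}}, {{4, 5}, {6, 7}}};

  // Path 1: slicing the sharded axis (axis 0) -> gather to a replicated tensor first,
  // then slice, matching the "reshard first" branch of the kernel.
  Matrix replicated = gather_rows(shards);
  Matrix y1 = slice_rows(replicated, /*start=*/1, /*end=*/3, /*step=*/1);
  assert((y1 == Matrix{{2, 3}, {4, 5}}));

  // Path 2: slicing a non-sharded axis (axis 1) -> each rank slices its local shard
  // independently and the output keeps the input's axis-0 sharding.
  Matrix y2_rank0 = slice_cols(shards[0], /*start=*/1, /*end=*/2, /*step=*/1);
  assert((y2_rank0 == Matrix{{1}, {3}}));
  return 0;
}
```

In the kernel itself, the gather step runs over NCCL via `ReshardTensor`, the local slice via `FuncSlice`, and the result may be resharded once more when the requested output spec differs from the data's sharding.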

---------

Co-authored-by: Wei-Sheng Chin <[email protected]>
Co-authored-by: Cheng Tang <[email protected]@orttrainingdev9.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Co-authored-by: Cheng Tang <[email protected]>
4 people authored Oct 12, 2023
1 parent 3f3ece4 commit ca8cab2
Showing 11 changed files with 449 additions and 29 deletions.
1 change: 1 addition & 0 deletions cmake/onnxruntime_providers_cuda.cmake
@@ -37,6 +37,7 @@
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/sharding_spec.cc"
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/sharding.cc"
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_matmul.cc"
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_slice.cc"
)
endif()
# add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio
16 changes: 1 addition & 15 deletions onnxruntime/contrib_ops/cuda/collective/distributed_matmul.cc
@@ -4,7 +4,6 @@
// Distributed computation.
#include "sharding.h"
#include "distributed_matmul.h"
#include "nccl_kernels.h"
#include "mpi_include.h"

// ORT system.
@@ -63,20 +62,7 @@ static TensorShape InferMatmulOutputShape(
};

template <typename T>
DistributedMatMul<T>::DistributedMatMul(const OpKernelInfo& info) : NcclKernel(info) {
std::vector<int64_t> device_mesh_elements = info.GetAttrsOrDefault<int64_t>("device_mesh_elements");
std::vector<int64_t> device_mesh_shape = info.GetAttrsOrDefault<int64_t>("device_mesh_shape");
std::vector<std::string> input_shard_specs = info.GetAttrsOrDefault<std::string>("input_shard_specs");
std::vector<std::string> output_shard_specs = info.GetAttrsOrDefault<std::string>("output_shard_specs");

for (size_t i = 0; i < input_shard_specs.size(); ++i) {
auto spec = CreateTensorPartitionSpec(input_shard_specs[i], device_mesh_shape, device_mesh_elements);
input_shard_specs_.push_back(spec);
}
for (size_t i = 0; i < output_shard_specs.size(); ++i) {
auto spec = CreateTensorPartitionSpec(output_shard_specs[i], device_mesh_shape, device_mesh_elements);
output_shard_specs_.push_back(spec);
}
DistributedMatMul<T>::DistributedMatMul(const OpKernelInfo& info) : DistributedKernel(info) {
}

template <typename T>
10 changes: 2 additions & 8 deletions onnxruntime/contrib_ops/cuda/collective/distributed_matmul.h
@@ -1,8 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "sharding_spec.h"
#include "core/providers/cuda/cuda_kernel.h"
#include "sharding.h"

#include <algorithm>
#include <tuple>
@@ -20,15 +18,11 @@ namespace cuda {
#if defined(ORT_USE_NCCL)

template <typename T>
class DistributedMatMul final : public NcclKernel {
class DistributedMatMul final : public DistributedKernel {
public:
explicit DistributedMatMul(const OpKernelInfo& info);

Status ComputeInternal(OpKernelContext* context) const override;

private:
std::vector<TensorPartitionSpec> input_shard_specs_;
std::vector<TensorPartitionSpec> output_shard_specs_;
};

#endif
181 changes: 181 additions & 0 deletions onnxruntime/contrib_ops/cuda/collective/distributed_slice.cc
@@ -0,0 +1,181 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// Distributed computation.
#include "distributed_slice.h"
#include "mpi_include.h"

// ORT system.
#include "core/providers/cpu/tensor/slice.h"
#include "core/providers/cuda/tensor/slice.h"
#include "core/providers/cuda/math/matmul.h"
#include "core/providers/cuda/tensor/transpose.h"
#include "core/providers/cuda/cuda_check_memory.h"

namespace onnxruntime {
namespace contrib {
namespace cuda {

#if defined(ORT_USE_NCCL)
template <typename T, typename Tind>
DistributedSlice<T, Tind>::DistributedSlice(const OpKernelInfo& info) : DistributedKernel(info) {
}

template <typename T, typename Tind>
Status DistributedSlice<T, Tind>::ComputeInternal(OpKernelContext* context) const {
const auto tensor_shard_data = context->Input<Tensor>(0);
const auto tensor_shard_starts = context->Input<Tensor>(1);
const auto tensor_shard_ends = context->Input<Tensor>(2);

const TensorPartitionSpec& spec_data = input_shard_specs_[0];
const TensorPartitionSpec& spec_starts = input_shard_specs_[1];
const TensorPartitionSpec& spec_ends = input_shard_specs_[2];
const TensorPartitionSpec& spec_Y = output_shard_specs_[0];

const auto tensor_shard_axes = context->Input<Tensor>(3);
const TensorPartitionSpec& spec_axes = input_shard_specs_[3];

if (spec_starts.HasShard() ||
spec_ends.HasShard() ||
spec_axes.HasShard() ||
(input_shard_specs_.size() > 4 && input_shard_specs_[4].HasShard()))
ORT_THROW("DistributedSlice: shard on starts / ends / axes / steps are not supported yet.");

std::vector<int64_t> input_starts;
std::vector<int64_t> input_ends;
auto starts_data = tensor_shard_starts->DataAsSpan<Tind>();
input_starts.resize(starts_data.size());
std::copy(starts_data.begin(), starts_data.end(), input_starts.begin());
auto ends_data = tensor_shard_ends->DataAsSpan<Tind>();
input_ends.resize(ends_data.size());
std::copy(ends_data.begin(), ends_data.end(), input_ends.begin());

std::vector<int64_t> input_axes;
if (tensor_shard_axes) {
auto axes_data = tensor_shard_axes->DataAsSpan<Tind>();
input_axes.resize(axes_data.size());
std::copy(axes_data.begin(), axes_data.end(), input_axes.begin());
}

std::vector<int64_t> input_steps;
const auto tensor_shard_steps = context->Input<Tensor>(4);
if (tensor_shard_steps) {
const TensorPartitionSpec& spec_steps = input_shard_specs_[4];
if (spec_steps.HasShard())
ORT_THROW("Not supported yet.");

auto steps_data = tensor_shard_steps->DataAsSpan<Tind>();
input_steps.resize(steps_data.size());
std::copy(steps_data.begin(), steps_data.end(), input_steps.begin());
}

if (spec_data.GetPartitionAxis() != -1 &&
std::find(input_axes.begin(), input_axes.end(), spec_data.GetPartitionAxis()) != input_axes.end()) {
// shard on slice axes, reshard first
auto tmp_spec_data = TensorPartitionSpec::CreateAllReplica(spec_data);
auto tensor_data = ReshardTensor(this, context, spec_data, tmp_spec_data, nccl_->Rank(), tensor_shard_data);

const auto& input_shape = tensor_data->Shape();
const auto input_dimensions = input_shape.GetDims();
if (input_dimensions.empty()) return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Cannot slice scalars");

SliceOp::PrepareForComputeMetadata compute_metadata(input_dimensions);
ORT_RETURN_IF_ERROR(SliceBase::PrepareForCompute(input_starts, input_ends, input_axes, input_steps, compute_metadata));
TensorShape output_shape(compute_metadata.output_dims_);

if (spec_Y.HasNoShard()) {
ORT_RETURN_IF_ERROR(FuncSlice(this,
context,
tensor_data.get(),
input_starts,
input_ends,
input_axes,
input_steps,
context->Output(0, output_shape)));
} else {
AllocatorPtr alloc;
ORT_ENFORCE(context->GetTempSpaceAllocator(&alloc) == Status::OK());
auto dst_tensor = Tensor::Create(tensor_data->DataType(), output_shape, alloc);
ORT_RETURN_IF_ERROR(FuncSlice(this,
context,
tensor_data.get(),
input_starts,
input_ends,
input_axes,
input_steps,
dst_tensor.get()));
auto tmp_spec_output = TensorPartitionSpec::CreateAllReplica(spec_Y);
ReshardTensor(this, context, tmp_spec_output, spec_Y, nccl_->Rank(), dst_tensor.get(), 0);
}
} else {
const auto& input_shape = tensor_shard_data->Shape();
const auto input_dimensions = input_shape.GetDims();
if (input_dimensions.empty()) return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Cannot slice scalars");

SliceOp::PrepareForComputeMetadata compute_metadata(input_dimensions);
ORT_RETURN_IF_ERROR(SliceBase::PrepareForCompute(input_starts, input_ends, input_axes, input_steps, compute_metadata));
TensorShape output_shape(compute_metadata.output_dims_);

if (spec_Y.GetPartitionAxis() == spec_data.GetPartitionAxis()) {
ORT_RETURN_IF_ERROR(FuncSlice(this,
context,
tensor_shard_data,
input_starts,
input_ends,
input_axes,
input_steps,
context->Output(0, output_shape)));
} else {
AllocatorPtr alloc;
ORT_ENFORCE(context->GetTempSpaceAllocator(&alloc) == Status::OK());
auto dst_tensor = Tensor::Create(tensor_shard_data->DataType(), output_shape, alloc);
ORT_RETURN_IF_ERROR(FuncSlice(this,
context,
tensor_shard_data,
input_starts,
input_ends,
input_axes,
input_steps,
dst_tensor.get()));
ReshardTensor(this, context, spec_data, spec_Y, nccl_->Rank(), dst_tensor.get(), 0);
}
}

return Status::OK();
}

ONNX_OPERATOR_TYPED_KERNEL_EX(
DistributedSlice,
kMSDomain,
1,
float,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 1)
.InputMemoryType(OrtMemTypeCPUInput, 2)
.InputMemoryType(OrtMemTypeCPUInput, 3)
.InputMemoryType(OrtMemTypeCPUInput, 4)
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("Tind", DataTypeImpl::GetTensorType<int64_t>()),
DistributedSlice<float, int64_t>);

ONNX_OPERATOR_TYPED_KERNEL_EX(
DistributedSlice,
kMSDomain,
1,
MLFloat16,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 1)
.InputMemoryType(OrtMemTypeCPUInput, 2)
.InputMemoryType(OrtMemTypeCPUInput, 3)
.InputMemoryType(OrtMemTypeCPUInput, 4)
.TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>())
.TypeConstraint("Tind", DataTypeImpl::GetTensorType<int64_t>()),
DistributedSlice<MLFloat16, int64_t>);

#endif

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
32 changes: 32 additions & 0 deletions onnxruntime/contrib_ops/cuda/collective/distributed_slice.h
@@ -0,0 +1,32 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once

#include <algorithm>
#include <tuple>
#include <optional>
#include <string>
#include <nccl.h>
#include <sstream>

#include "sharding.h"

namespace onnxruntime {
namespace contrib {
namespace cuda {

#if defined(ORT_USE_NCCL)

template <typename T, typename Tind>
class DistributedSlice final : public DistributedKernel {
public:
explicit DistributedSlice(const OpKernelInfo& info);

Status ComputeInternal(OpKernelContext* context) const override;
};

#endif

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
40 changes: 40 additions & 0 deletions onnxruntime/contrib_ops/cuda/collective/sharding.cc
@@ -212,6 +212,46 @@ std::unique_ptr<Tensor> ReshardTensor(
return dst;
}

void ReshardTensor(
const NcclKernel* nccl_kernel,
OpKernelContext* ctx,
const TensorPartitionSpec& src_spec,
const TensorPartitionSpec& dst_spec,
const int64_t device_id,
const Tensor* src,
int output_idx) {
// Same resharding as the unique_ptr-returning overload, but the destination is the kernel context's output at output_idx.
const auto origin_shape = ComputeOriginShape(src->Shape(), src_spec);
const auto dst_shape = ComputeShardShape(origin_shape, dst_spec);
ORT_ENFORCE(CanShard(origin_shape, dst_spec), "Cannot shard tensor. Shape:", origin_shape, ", sharding spec: ", dst_spec.ToString());

auto* dst = ctx->Output(output_idx, dst_shape);
ReshardTensor(
nccl_kernel,
ctx,
src_spec,
dst_spec,
device_id,
src,
dst);
}

DistributedKernel::DistributedKernel(const OpKernelInfo& info) : NcclKernel(info) {
std::vector<int64_t> device_mesh_elements = info.GetAttrsOrDefault<int64_t>("device_mesh_elements");
std::vector<int64_t> device_mesh_shape = info.GetAttrsOrDefault<int64_t>("device_mesh_shape");
std::vector<std::string> input_shard_specs = info.GetAttrsOrDefault<std::string>("input_shard_specs");
std::vector<std::string> output_shard_specs = info.GetAttrsOrDefault<std::string>("output_shard_specs");

for (size_t i = 0; i < input_shard_specs.size(); ++i) {
auto spec = CreateTensorPartitionSpec(input_shard_specs[i], device_mesh_shape, device_mesh_elements);
input_shard_specs_.push_back(spec);
}
for (size_t i = 0; i < output_shard_specs.size(); ++i) {
auto spec = CreateTensorPartitionSpec(output_shard_specs[i], device_mesh_shape, device_mesh_elements);
output_shard_specs_.push_back(spec);
}
}

#endif

} // namespace cuda
24 changes: 22 additions & 2 deletions onnxruntime/contrib_ops/cuda/collective/sharding.h
@@ -1,11 +1,10 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once

#include "sharding_spec.h"
#include "nccl_kernels.h"

#pragma once

namespace onnxruntime {
namespace contrib {
namespace cuda {
@@ -49,6 +48,16 @@ void ReshardTensor(
const Tensor* src,
Tensor* dst);

// Output from ctx
void ReshardTensor(
const NcclKernel* nccl_kernel,
OpKernelContext* ctx,
const TensorPartitionSpec& src_spec,
const TensorPartitionSpec& dst_spec,
const int64_t device_id,
const Tensor* src,
int output_idx);

std::unique_ptr<Tensor> ReshardTensor(
const NcclKernel* nccl_kernel,
OpKernelContext* ctx,
@@ -57,6 +66,17 @@ std::unique_ptr<Tensor> ReshardTensor(
const int64_t device_id,
const Tensor* src);

class TensorPartitionSpec;

class DistributedKernel : public NcclKernel {
public:
explicit DistributedKernel(const OpKernelInfo& info);

protected:
std::vector<TensorPartitionSpec> input_shard_specs_;
std::vector<TensorPartitionSpec> output_shard_specs_;
};

#endif

} // namespace cuda
3 changes: 1 addition & 2 deletions onnxruntime/contrib_ops/cuda/collective/sharding_spec.h
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once

#include "core/common/common.h"
#include "core/framework/tensor_shape.h"
@@ -8,8 +9,6 @@
#include <sstream>
#include <vector>

#pragma once

namespace onnxruntime {
namespace contrib {
namespace cuda {
