Commit: all-replica version

wschin committed Oct 31, 2023
1 parent 9e58239 commit c4cdf3e
Showing 6 changed files with 145 additions and 50 deletions.
137 changes: 97 additions & 40 deletions onnxruntime/contrib_ops/cuda/collective/distributed_reduce.cc
@@ -10,6 +10,7 @@
#include "mpi_include.h"

// ORT system.
#include "core/providers/cuda/cudnn_common.h"
#include "core/providers/cuda/reduction/reduction_ops.h"

// std C++.
@@ -22,44 +23,120 @@ namespace cuda {
#if defined(ORT_USE_NCCL)

template <typename T>
DistributedReduceSum<T>::DistributedReduceSum(const OpKernelInfo& info) : DistributedKernel(info) {};
DistributedReduceSum<T>::DistributedReduceSum(const OpKernelInfo& info) : DistributedKernel(info) {
keepdims_ = info.GetAttrOrDefault<int64_t>("keepdims", 1);
cudnn_reduce_op_ = CUDNN_REDUCE_TENSOR_ADD;
};

template <typename T>
DistributedReduceMean<T>::DistributedReduceMean(const OpKernelInfo& info) : DistributedKernel(info) {};
DistributedReduceMean<T>::DistributedReduceMean(const OpKernelInfo& info) : DistributedKernel(info) {
keepdims_ = info.GetAttrOrDefault<int64_t>("keepdims", 1);
cudnn_reduce_op_ = CUDNN_REDUCE_TENSOR_AVG;
};

template <typename T>
DistributedReduceMax<T>::DistributedReduceMax(const OpKernelInfo& info) : DistributedKernel(info) {};
DistributedReduceMax<T>::DistributedReduceMax(const OpKernelInfo& info) : DistributedKernel(info) {
keepdims_ = info.GetAttrOrDefault<int64_t>("keepdims", 1);
cudnn_reduce_op_ = CUDNN_REDUCE_TENSOR_MAX;
};

template <typename T>
Status DistributedReduceSum<T>::ComputeInternal(OpKernelContext* context) const {
ORT_ENFORCE(context != nullptr);
return Status::OK();
const Tensor* input_tensor = context->Input<Tensor>(0);
const Tensor* axes_tensor = context->Input<Tensor>(1);
ORT_ENFORCE(axes_tensor != nullptr, "Axes input cannot be null. Please check the 2nd input.");
ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, "Axes tensor must be a 1-D tensor.");
auto axes_span = axes_tensor->DataAsSpan<int64_t>();

// Case 1: empty axes means treating this reduction as a no-op.
if (axes_span.empty()) {
auto* output_tensor = context->Output(0, input_tensor->Shape());
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor->MutableData<T>(), input_tensor->Data<T>(), input_tensor->SizeInBytes(),
cudaMemcpyDeviceToDevice, Stream(context)));
return Status::OK();
}

// Case 2: this is a valid reduction. Let's prepare for it.
onnxruntime::cuda::PrepareReduceMetadata metadata;
ORT_RETURN_IF_ERROR(
onnxruntime::cuda::PrepareForReduce(input_tensor, keepdims_, axes_span, metadata)
);
auto output_tensor = context->Output(0, metadata.squeezed_output_dims);

// Fast reduction is not deterministic, so sometimes we want to turn it off.
const bool enable_fast_but_non_deterministic_reduction = !context->GetUseDeterministicCompute();
return onnxruntime::cuda::ReduceComputeCore<T, CUDNN_REDUCE_TENSOR_NO_INDICES>(
/* GPU allocator */ Info().GetAllocator(OrtMemType::OrtMemTypeDefault),
*input_tensor, metadata, *output_tensor, cudnn_reduce_op_, axes_span,
/* calculate_log */ false, /* calculate_sqt */ false, /* log_sum_exp */ false,
enable_fast_but_non_deterministic_reduction, context->GetComputeStream());
}

template <typename T>
Status DistributedReduceMean<T>::ComputeInternal(OpKernelContext* context) const {
ORT_ENFORCE(context != nullptr);
return Status::OK();
const Tensor* input_tensor = context->Input<Tensor>(0);
const Tensor* axes_tensor = context->Input<Tensor>(1);
ORT_ENFORCE(axes_tensor != nullptr, "Axes input cannot be null. Please check the 2nd input.");
ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, "Axes tensor must be a 1-D tensor.");
auto axes_span = axes_tensor->DataAsSpan<int64_t>();

// Case 1: empty axes means treating this reduction as a no-op.
if (axes_span.empty()) {
auto* output_tensor = context->Output(0, input_tensor->Shape());
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor->MutableData<T>(), input_tensor->Data<T>(), input_tensor->SizeInBytes(),
cudaMemcpyDeviceToDevice, Stream(context)));
return Status::OK();
}

// Case 2: this is a valid reduction. Let's prepare for it.
onnxruntime::cuda::PrepareReduceMetadata metadata;
ORT_RETURN_IF_ERROR(
onnxruntime::cuda::PrepareForReduce(input_tensor, keepdims_, axes_span, metadata)
);
auto output_tensor = context->Output(0, metadata.squeezed_output_dims);

// Fast reduction is not deterministic, so sometimes we want to turn it off.
const bool enable_fast_but_non_deterministic_reduction = !context->GetUseDeterministicCompute();
return onnxruntime::cuda::ReduceComputeCore<T, CUDNN_REDUCE_TENSOR_NO_INDICES>(
/* GPU allocator */ Info().GetAllocator(OrtMemType::OrtMemTypeDefault),
*input_tensor, metadata, *output_tensor, cudnn_reduce_op_, axes_span,
/* calculate_log */ false, /* calculate_sqt */ false, /* log_sum_exp */ false,
enable_fast_but_non_deterministic_reduction, context->GetComputeStream());
}

template <typename T>
Status DistributedReduceMax<T>::ComputeInternal(OpKernelContext* context) const {
ORT_ENFORCE(context != nullptr);
return Status::OK();
const Tensor* input_tensor = context->Input<Tensor>(0);
const Tensor* axes_tensor = context->Input<Tensor>(1);
ORT_ENFORCE(axes_tensor != nullptr, "Axes input cannot be null. Please check the 2nd input.");
ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, "Axes tensor must be a 1-D tensor.");
auto axes_span = axes_tensor->DataAsSpan<int64_t>();

// Case 1: empty axes means treating this reduction as a no-op.
if (axes_span.empty()) {
auto* output_tensor = context->Output(0, input_tensor->Shape());
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor->MutableData<T>(), input_tensor->Data<T>(), input_tensor->SizeInBytes(),
cudaMemcpyDeviceToDevice, Stream(context)));
return Status::OK();
}

// Case 2: this is a valid reduction. Let's prepare for it.
onnxruntime::cuda::PrepareReduceMetadata metadata;
ORT_RETURN_IF_ERROR(
onnxruntime::cuda::PrepareForReduce(input_tensor, keepdims_, axes_span, metadata)
);
auto output_tensor = context->Output(0, metadata.squeezed_output_dims);

// Fast reduction is not deterministic, so sometimes we want to turn it off.
const bool enable_fast_but_non_deterministic_reduction = !context->GetUseDeterministicCompute();
return onnxruntime::cuda::ReduceComputeCore<T, CUDNN_REDUCE_TENSOR_NO_INDICES>(
/* GPU allocator */ Info().GetAllocator(OrtMemType::OrtMemTypeDefault),
*input_tensor, metadata, *output_tensor, cudnn_reduce_op_, axes_span,
/* calculate_log */ false, /* calculate_sqt */ false, /* log_sum_exp */ false,
enable_fast_but_non_deterministic_reduction, context->GetComputeStream());
}
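
The three ComputeInternal bodies above are identical except for cudnn_reduce_op_, which each constructor sets. A minimal sketch of a shared helper they could delegate to; ReduceCommon, its signature, and the copy_stream parameter are illustrative only and not part of this commit:

template <typename T>
static Status ReduceCommon(OpKernelContext* context, const AllocatorPtr& gpu_allocator,
                           bool keepdims, cudnnReduceTensorOp_t cudnn_reduce_op,
                           cudaStream_t copy_stream) {
  const Tensor* input_tensor = context->Input<Tensor>(0);
  const Tensor* axes_tensor = context->Input<Tensor>(1);
  ORT_ENFORCE(axes_tensor != nullptr, "Axes input cannot be null. Please check the 2nd input.");
  ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, "Axes tensor must be a 1-D tensor.");
  auto axes_span = axes_tensor->DataAsSpan<int64_t>();

  // Empty axes: pass the input through unchanged.
  if (axes_span.empty()) {
    auto* output_tensor = context->Output(0, input_tensor->Shape());
    CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor->MutableData<T>(), input_tensor->Data<T>(),
                                         input_tensor->SizeInBytes(), cudaMemcpyDeviceToDevice,
                                         copy_stream));
    return Status::OK();
  }

  // Non-empty axes: reuse the shared CUDA reduction path.
  onnxruntime::cuda::PrepareReduceMetadata metadata;
  ORT_RETURN_IF_ERROR(onnxruntime::cuda::PrepareForReduce(input_tensor, keepdims, axes_span, metadata));
  auto* output_tensor = context->Output(0, metadata.squeezed_output_dims);
  const bool enable_fast_but_non_deterministic_reduction = !context->GetUseDeterministicCompute();
  return onnxruntime::cuda::ReduceComputeCore<T, CUDNN_REDUCE_TENSOR_NO_INDICES>(
      gpu_allocator, *input_tensor, metadata, *output_tensor, cudnn_reduce_op, axes_span,
      /* calculate_log */ false, /* calculate_sqt */ false, /* log_sum_exp */ false,
      enable_fast_but_non_deterministic_reduction, context->GetComputeStream());
}

Each kernel's ComputeInternal would then reduce to a single call, for example ReduceCommon<T>(context, Info().GetAllocator(OrtMemType::OrtMemTypeDefault), keepdims_, cudnn_reduce_op_, Stream(context)).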

// ReduceSum
ONNX_OPERATOR_TYPED_KERNEL_EX(
DistributedReduceSum,
kMSDomain,
1,
int64_t,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::GetTensorType<int64_t>())
// Reduced axes are a small 1-D tensor, so we can use CPU memory.
.InputMemoryType(OrtMemTypeCPUInput, 1),
DistributedReduceSum<int64_t>);
ONNX_OPERATOR_TYPED_KERNEL_EX(
DistributedReduceSum,
kMSDomain,
@@ -82,16 +159,6 @@ ONNX_OPERATOR_TYPED_KERNEL_EX(
DistributedReduceSum<MLFloat16>);

// ReduceMean
ONNX_OPERATOR_TYPED_KERNEL_EX(
DistributedReduceMean,
kMSDomain,
1,
int64_t,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::GetTensorType<int64_t>())
.InputMemoryType(OrtMemTypeCPUInput, 1),
DistributedReduceMean<int64_t>);
ONNX_OPERATOR_TYPED_KERNEL_EX(
DistributedReduceMean,
kMSDomain,
@@ -114,16 +181,6 @@ ONNX_OPERATOR_TYPED_KERNEL_EX(
DistributedReduceMean<MLFloat16>);

// ReduceMax
ONNX_OPERATOR_TYPED_KERNEL_EX(
DistributedReduceMax,
kMSDomain,
1,
int64_t,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::GetTensorType<int64_t>())
.InputMemoryType(OrtMemTypeCPUInput, 1),
DistributedReduceMax<int64_t>);
ONNX_OPERATOR_TYPED_KERNEL_EX(
DistributedReduceMax,
kMSDomain,
18 changes: 18 additions & 0 deletions onnxruntime/contrib_ops/cuda/collective/distributed_reduce.h
@@ -26,6 +26,12 @@ class DistributedReduceSum final : public DistributedKernel {
explicit DistributedReduceSum(const OpKernelInfo& info);

Status ComputeInternal(OpKernelContext* context) const override;

private:
// ONNX attribute. If true, reduced axes are retained as dimensions with size one.
// Otherwise, reduced axes are dropped.
bool keepdims_;
cudnnReduceTensorOp_t cudnn_reduce_op_;
};

template <typename T>
@@ -34,6 +40,12 @@ class DistributedReduceMean final : public DistributedKernel {
explicit DistributedReduceMean(const OpKernelInfo& info);

Status ComputeInternal(OpKernelContext* context) const override;

private:
// ONNX attribute. If true, reduced axes are retained as dimensions with size one.
// Otherwise, reduced axes are dropped.
bool keepdims_;
cudnnReduceTensorOp_t cudnn_reduce_op_;
};

template <typename T>
@@ -42,6 +54,12 @@ class DistributedReduceMax final : public DistributedKernel {
explicit DistributedReduceMax(const OpKernelInfo& info);

Status ComputeInternal(OpKernelContext* context) const override;

private:
// ONNX attribute. If true, reduced axes are retained as dimensions with size one.
// Otherwise, reduced axes are dropped.
bool keepdims_;
cudnnReduceTensorOp_t cudnn_reduce_op_;
};

#endif
3 changes: 0 additions & 3 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -365,15 +365,12 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedExpand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedExpand)>,

BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedReduceSum)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceSum)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceSum)>,

BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedReduceMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceMax)>,

BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedReduceMean)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceMean)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceMean)>,
#endif
9 changes: 6 additions & 3 deletions onnxruntime/core/graph/contrib_ops/collective_defs.cc
@@ -300,7 +300,8 @@ void RegisterCollectiveOps() {
AttributeProto::STRINGS)
.Attr("keepdims",
"Keep the reduced dimension or not, default 1 means keep reduced dimension.",
AttributeProto::INT)
AttributeProto::INT,
static_cast<int64_t>(1))
.Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
.Input(
1,
@@ -340,7 +341,8 @@ void RegisterCollectiveOps() {
AttributeProto::STRINGS)
.Attr("keepdims",
"Keep the reduced dimension or not, default 1 means keep reduced dimension.",
AttributeProto::INT)
AttributeProto::INT,
static_cast<int64_t>(1))
.Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
.Input(
1,
@@ -380,7 +382,8 @@ void RegisterCollectiveOps() {
AttributeProto::STRINGS)
.Attr("keepdims",
"Keep the reduced dimension or not, default 1 means keep reduced dimension.",
AttributeProto::INT)
AttributeProto::INT,
static_cast<int64_t>(1))
.Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
.Input(
1,
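
With the default now carried by the schema, the kernel-side read in distributed_reduce.cc stays consistent with it. A small illustrative sketch of the resulting behaviour; the shapes are examples only, not taken from this commit:

// keepdims omitted on the node -> the schema default and GetAttrOrDefault both yield 1.
int64_t keepdims = info.GetAttrOrDefault<int64_t>("keepdims", 1);

// Reducing an input of shape [2, 3, 4] over axes = {1}:
//   keepdims == 1  ->  output shape [2, 1, 4]  (reduced axis kept with size one)
//   keepdims == 0  ->  output shape [2, 4]     (reduced axis dropped)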
24 changes: 24 additions & 0 deletions onnxruntime/core/providers/cuda/reduction/reduction_ops.cc
@@ -614,6 +614,30 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input,
return Status::OK();
}

template Status ReduceComputeCore<float, CUDNN_REDUCE_TENSOR_NO_INDICES>(
const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata,
/*out*/ Tensor& output, cudnnReduceTensorOp_t cudnn_reduce_op,
gsl::span<const int64_t> axes,
bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction,
Stream* ort_stream,
const TensorShape* input_shape_override);

template Status ReduceComputeCore<double, CUDNN_REDUCE_TENSOR_NO_INDICES>(
const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata,
/*out*/ Tensor& output, cudnnReduceTensorOp_t cudnn_reduce_op,
gsl::span<const int64_t> axes,
bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction,
Stream* ort_stream,
const TensorShape* input_shape_override);

template Status ReduceComputeCore<MLFloat16, CUDNN_REDUCE_TENSOR_NO_INDICES>(
const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata,
/*out*/ Tensor& output, cudnnReduceTensorOp_t cudnn_reduce_op,
gsl::span<const int64_t> axes,
bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction,
Stream* ort_stream,
const TensorShape* input_shape_override);

template <bool allow_multi_axes>
template <typename T, cudnnReduceTensorIndices_t ReduceTensorIndices>
Status ReduceKernel<allow_multi_axes>::ComputeImpl(OpKernelContext* ctx, cudnnReduceTensorOp_t cudnn_reduce_op) const {
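
The explicit instantiations above are what allow distributed_reduce.cc, a separate translation unit, to call ReduceComputeCore even though the template's definition lives in this .cc file. A generic, self-contained illustration of the same pattern; the names below are invented for the example and are not part of ONNX Runtime:

// lib.h
template <typename T> T triple(T x);

// lib.cc -- the template body is visible only here.
template <typename T> T triple(T x) { return x + x + x; }
template int triple<int>(int);  // explicit instantiation: emits triple<int> into lib.o

// user.cc -- compiles against lib.h alone; links because lib.o already contains triple<int>.
#include "lib.h"
int use() { return triple(7); }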
4 changes: 0 additions & 4 deletions onnxruntime/test/python/onnxruntime_test_distributed.py
@@ -920,7 +920,6 @@ def distributed_reduce_mean_instance(data_tensor: FLOAT, axes_tensor: INT64):
):
data = np.random.randint(4, size=shape).astype(dtype)
expected = np_func(data, axis=axes, keepdims=bool(keepdims))
print("origin shape: ", data.shape)

assert len(input_shard_specs) == 2 and len(input_device_meshes) == 2, "Reduce has two inputs."
assert "S" not in input_shard_specs[1], "Tensor `axes` should not be sharded."
@@ -929,7 +928,6 @@ def distributed_reduce_mean_instance(data_tensor: FLOAT, axes_tensor: INT64):
local_data = shard_tensor_per_spec(data, rank, input_shard_specs[0], input_device_meshes[0])
local_expected = shard_tensor_per_spec(expected, rank, output_shard_specs[0], output_device_meshes[0])

print("local shape: ", local_data.shape)
if dtype == np.float32:
onnx_model = onnx_func.to_model_proto(
input_types=[FLOAT[tuple(local_data.shape)], INT64[len(axes)]],
@@ -958,8 +956,6 @@ def distributed_reduce_mean_instance(data_tensor: FLOAT, axes_tensor: INT64):
# Each MPI process executes its sharded model.
# The result is `local` tensor stored on a specific MPI rank
# instead of `logical` tensor.
print(local_data)
print(np.array(axes, dtype=np.int64))
result = sess.run(
None,
{