
Commit

Allow more cases without resharding
wschin committed Oct 31, 2023
1 parent c4cdf3e · commit 5014fb9
Showing 3 changed files with 70 additions and 108 deletions.
133 changes: 45 additions & 88 deletions onnxruntime/contrib_ops/cuda/collective/distributed_reduce.cc
@@ -23,117 +23,74 @@ namespace cuda {
 #if defined(ORT_USE_NCCL)

 template <typename T>
-DistributedReduceSum<T>::DistributedReduceSum(const OpKernelInfo& info) : DistributedKernel(info) {
+DistributedReduceBase<T>::DistributedReduceBase(
+    const OpKernelInfo& info,
+    cudnnReduceTensorOp_t cudnn_reduce_op) : DistributedKernel(info) {
   keepdims_ = info.GetAttrOrDefault<int64_t>("keepdims", 1);
-  cudnn_reduce_op_ = CUDNN_REDUCE_TENSOR_ADD;
+  cudnn_reduce_op_ = cudnn_reduce_op;
 };

 template <typename T>
-DistributedReduceMean<T>::DistributedReduceMean(const OpKernelInfo& info) : DistributedKernel(info) {
-  keepdims_ = info.GetAttrOrDefault<int64_t>("keepdims", 1);
-  cudnn_reduce_op_ = CUDNN_REDUCE_TENSOR_AVG;
-};
+Status DistributedReduceBase<T>::ComputeInternal(OpKernelContext* context) const {
+  const auto& input_sharding_spec = input_shard_specs_.at(0);
+  const auto& axes_sharding_spec = input_shard_specs_.at(1);
+  const auto& output_sharding_spec = output_shard_specs_.at(0);

-template <typename T>
-DistributedReduceMax<T>::DistributedReduceMax(const OpKernelInfo& info) : DistributedKernel(info) {
-  keepdims_ = info.GetAttrOrDefault<int64_t>("keepdims", 1);
-  cudnn_reduce_op_ = CUDNN_REDUCE_TENSOR_MAX;
-};
+  ORT_ENFORCE(axes_sharding_spec.HasNoShard(),
+              "It's not worthy to shard axes tensor. "
+              "If sharding axes is needed, please submit a feature request.");

-template <typename T>
-Status DistributedReduceSum<T>::ComputeInternal(OpKernelContext* context) const {
   const Tensor* input_tensor = context->Input<Tensor>(0);
   const Tensor* axes_tensor = context->Input<Tensor>(1);
   ORT_ENFORCE(axes_tensor != nullptr, "Axes input cannot be null. Please check the 2nd input.");
   ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, "Axes tensor must be an 1-D tensor.");
   auto axes_span = axes_tensor->DataAsSpan<int64_t>();

-  // Case 1: empty axes means treating this reduction as a no-op.
+  // Case 1: empty axes means treating this reduction as an identity.
   if (axes_span.empty()) {
+    ORT_ENFORCE(
+        input_sharding_spec == output_sharding_spec,
+        "Input and output sharding specs should be the same. Otherwise, resharding is needed."
+    );
     auto* output_tensor = context->Output(0, input_tensor->Shape());
     CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor->MutableData<T>(), input_tensor->Data<T>(), input_tensor->SizeInBytes(),
                                          cudaMemcpyDeviceToDevice, Stream(context)));
     return Status::OK();
   }

-  // Case 2: this is a valid reduction. Let's prepare for it.
-  onnxruntime::cuda::PrepareReduceMetadata metadata;
-  ORT_RETURN_IF_ERROR(
-      onnxruntime::cuda::PrepareForReduce(input_tensor, keepdims_, axes_span, metadata)
-  );
-  auto output_tensor = context->Output(0, metadata.squeezed_output_dims);
-
-  // Fast reduction is not deterministic, so sometimes we want to turn it off.
-  const bool enable_fast_but_non_deterministic_reduction = !context->GetUseDeterministicCompute();
-  return onnxruntime::cuda::ReduceComputeCore<T, CUDNN_REDUCE_TENSOR_NO_INDICES>(
-      /* GPU allocator */ Info().GetAllocator(OrtMemType::OrtMemTypeDefault),
-      *input_tensor, metadata, *output_tensor, cudnn_reduce_op_, axes_span,
-      /* calculate_log */ false, /* calculate_sqt */ false, /* log_sum_exp_ */ false,
-      enable_fast_but_non_deterministic_reduction, context->GetComputeStream());
-}
-
-template <typename T>
-Status DistributedReduceMean<T>::ComputeInternal(OpKernelContext* context) const {
-  const Tensor* input_tensor = context->Input<Tensor>(0);
-  const Tensor* axes_tensor = context->Input<Tensor>(1);
-  ORT_ENFORCE(axes_tensor != nullptr, "Axes input cannot be null. Please check the 2nd input.");
-  ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, "Axes tensor must be an 1-D tensor.");
-  auto axes_span = axes_tensor->DataAsSpan<int64_t>();
-
-  // Case 1: empty axes means treating this reduction as a no-op.
-  if (axes_span.empty()) {
-    auto* output_tensor = context->Output(0, input_tensor->Shape());
-    CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor->MutableData<T>(), input_tensor->Data<T>(), input_tensor->SizeInBytes(),
-                                         cudaMemcpyDeviceToDevice, Stream(context)));
-    return Status::OK();
+  bool sharding_on_reduced_axes = false;
+  for (auto axis_it = axes_span.begin(); input_sharding_spec.HasShard() && axis_it != axes_span.end(); ++axis_it) {
+    if (*axis_it == input_sharding_spec.GetPartitionAxis()) {
+      sharding_on_reduced_axes = true;
+      break;
+    }
   }

-  // Case 2: this is a valid reduction. Let's prepare for it.
-  onnxruntime::cuda::PrepareReduceMetadata metadata;
-  ORT_RETURN_IF_ERROR(
-      onnxruntime::cuda::PrepareForReduce(input_tensor, keepdims_, axes_span, metadata)
-  );
-  auto output_tensor = context->Output(0, metadata.squeezed_output_dims);
-
-  // Fast reduction is not deterministic, so sometimes we want to turn it off.
-  const bool enable_fast_but_non_deterministic_reduction = !context->GetUseDeterministicCompute();
-  return onnxruntime::cuda::ReduceComputeCore<T, CUDNN_REDUCE_TENSOR_NO_INDICES>(
-      /* GPU allocator */ Info().GetAllocator(OrtMemType::OrtMemTypeDefault),
-      *input_tensor, metadata, *output_tensor, cudnn_reduce_op_, axes_span,
-      /* calculate_log */ false, /* calculate_sqt */ false, /* log_sum_exp_ */ false,
-      enable_fast_but_non_deterministic_reduction, context->GetComputeStream());
-}
-
-template <typename T>
-Status DistributedReduceMax<T>::ComputeInternal(OpKernelContext* context) const {
-  const Tensor* input_tensor = context->Input<Tensor>(0);
-  const Tensor* axes_tensor = context->Input<Tensor>(1);
-  ORT_ENFORCE(axes_tensor != nullptr, "Axes input cannot be null. Please check the 2nd input.");
-  ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, "Axes tensor must be an 1-D tensor.");
-  auto axes_span = axes_tensor->DataAsSpan<int64_t>();
-
-  // Case 1: empty axes means treating this reduction as a no-op.
-  if (axes_span.empty()) {
-    auto* output_tensor = context->Output(0, input_tensor->Shape());
-    CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor->MutableData<T>(), input_tensor->Data<T>(), input_tensor->SizeInBytes(),
-                                         cudaMemcpyDeviceToDevice, Stream(context)));
-    return Status::OK();
+  if (sharding_on_reduced_axes) {
+    // Case 2-1: sharding on reduced axes.
+    ORT_THROW(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::FAIL, "Not implemented. Resharding is required to make reduced axes replica.");
+  } else {
+    // Case 2-2: sharding on passing-through axes or no shard.
+    ORT_ENFORCE(
+        input_sharding_spec == output_sharding_spec,
+        "Input and output sharding specs should be the same. Otherwise, resharding is needed."
+    );
+    onnxruntime::cuda::PrepareReduceMetadata metadata;
+    ORT_RETURN_IF_ERROR(
+        onnxruntime::cuda::PrepareForReduce(input_tensor, keepdims_, axes_span, metadata)
+    );
+    auto output_tensor = context->Output(0, metadata.squeezed_output_dims);
+
+    // Fast reduction is not deterministic, so sometimes we want to turn it off.
+    const bool enable_fast_but_non_deterministic_reduction = !context->GetUseDeterministicCompute();
+    return onnxruntime::cuda::ReduceComputeCore<T, CUDNN_REDUCE_TENSOR_NO_INDICES>(
+        /* GPU allocator */ Info().GetAllocator(OrtMemType::OrtMemTypeDefault),
+        *input_tensor, metadata, *output_tensor, cudnn_reduce_op_, axes_span,
+        /* calculate_log */ false, /* calculate_sqt */ false, /* log_sum_exp_ */ false,
+        enable_fast_but_non_deterministic_reduction, context->GetComputeStream());
   }

-  // Case 2: this is a valid reduction. Let's prepare for it.
-  onnxruntime::cuda::PrepareReduceMetadata metadata;
-  ORT_RETURN_IF_ERROR(
-      onnxruntime::cuda::PrepareForReduce(input_tensor, keepdims_, axes_span, metadata)
-  );
-  auto output_tensor = context->Output(0, metadata.squeezed_output_dims);
-
-  // Fast reduction is not deterministic, so sometimes we want to turn it off.
-  const bool enable_fast_but_non_deterministic_reduction = !context->GetUseDeterministicCompute();
-  return onnxruntime::cuda::ReduceComputeCore<T, CUDNN_REDUCE_TENSOR_NO_INDICES>(
-      /* GPU allocator */ Info().GetAllocator(OrtMemType::OrtMemTypeDefault),
-      *input_tensor, metadata, *output_tensor, cudnn_reduce_op_, axes_span,
-      /* calculate_log */ false, /* calculate_sqt */ false, /* log_sum_exp_ */ false,
-      enable_fast_but_non_deterministic_reduction, context->GetComputeStream());
+  return Status::OK();
 }

 // ReduceSum
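The refactored ComputeInternal comes down to one rule: when none of the reduced axes is the sharded axis, every rank can run the reduction locally on its own shard and the output keeps the input's sharding spec, so no resharding or communication is needed; only a reduction along the sharded axis (Case 2-1) would force resharding, and that path is still unimplemented. Below is a minimal NumPy sketch of that rule, written for this summary and not part of ONNX Runtime:

# Sketch only (not ORT code): reducing along a non-sharded axis commutes with
# sharding, so each rank reduces its own shard and the output stays "S[0]R".
import numpy as np

full = np.arange(32, dtype=np.float32).reshape(8, 4)
shards = np.split(full, 2, axis=0)                     # "S[0]R" across 2 ranks

# Each rank reduces locally along axis 1 (the replicated axis).
local_out = [s.sum(axis=1, keepdims=True) for s in shards]

# Concatenating the per-rank results reproduces the sharded full reduction.
expected = np.split(full.sum(axis=1, keepdims=True), 2, axis=0)
assert all(np.allclose(a, b) for a, b in zip(local_out, expected))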
30 changes: 10 additions & 20 deletions onnxruntime/contrib_ops/cuda/collective/distributed_reduce.h
@@ -21,9 +21,9 @@ namespace cuda {
 #if defined(ORT_USE_NCCL)

 template <typename T>
-class DistributedReduceSum final : public DistributedKernel {
+class DistributedReduceBase : public DistributedKernel {
  public:
-  explicit DistributedReduceSum(const OpKernelInfo& info);
+  explicit DistributedReduceBase(const OpKernelInfo& info, cudnnReduceTensorOp_t cudnn_reduce_op);

   Status ComputeInternal(OpKernelContext* context) const override;

@@ -35,31 +35,21 @@ class DistributedReduceSum final : public DistributedKernel {
 };

 template <typename T>
-class DistributedReduceMean final : public DistributedKernel {
+class DistributedReduceSum final : public DistributedReduceBase {
  public:
-  explicit DistributedReduceMean(const OpKernelInfo& info);
-
-  Status ComputeInternal(OpKernelContext* context) const override;
+  explicit DistributedReduceSum(const OpKernelInfo& info);
+};

- private:
-  // ONNX attribute. If true, reduced axes are retained as dimensions with size one.
-  // Otherwise, drop reduced axes.
-  bool keepdims_;
-  cudnnReduceTensorOp_t cudnn_reduce_op_;
+template <typename T>
+class DistributedReduceMean final : public DistributedReduceBase {
+ public:
+  explicit DistributedReduceMean(const OpKernelInfo& info);
 };

 template <typename T>
-class DistributedReduceMax final : public DistributedKernel {
+class DistributedReduceMax final : public DistributedReduceBase {
  public:
   explicit DistributedReduceMax(const OpKernelInfo& info);
-
-  Status ComputeInternal(OpKernelContext* context) const override;
-
- private:
-  // ONNX attribute. If true, reduced axes are retained as dimensions with size one.
-  // Otherwise, drop reduced axes.
-  bool keepdims_;
-  cudnnReduceTensorOp_t cudnn_reduce_op_;
 };

 #endif
15 changes: 15 additions & 0 deletions onnxruntime/test/python/onnxruntime_test_distributed.py
@@ -983,6 +983,21 @@ def test_reduce(self):
             output_shard_specs=("RR",),
         )

+    def test_reduce_sharded(self):
+        self._check_distributed_reduce(
+            keepdims=1,
+            dtype=np.float32,
+            shape=(
+                8,
+                4,
+            ),
+            axes=(1,),
+            input_device_meshes=[np.array([0, 1])] * 2,
+            input_shard_specs=("S[0]R", "R"),
+            output_device_meshes=[np.array([0, 1])],
+            output_shard_specs=("S[0]R",),
+        )
+

 class TestDistributed(unittest.TestCase):
     def test_matmul_rs_sr_rr(self):
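For the new test_reduce_sharded case, the (8, 4) input sharded as "S[0]R" over the two-device mesh [0, 1] gives each rank a (4, 4) slice; ReduceSum over axes=(1,) with keepdims=1 then produces a (4, 1) slice per rank, which is exactly the "S[0]R" sharding of the full (8, 1) result declared in output_shard_specs. A rough per-rank shape walk-through (an illustrative sketch, not part of the test suite):

# Sketch only: the per-rank shapes exercised by test_reduce_sharded.
import numpy as np

full_input = np.random.rand(8, 4).astype(np.float32)
rank_shards = np.split(full_input, 2, axis=0)        # "S[0]R" across ranks 0 and 1

for rank, shard in enumerate(rank_shards):
    local = shard.sum(axis=1, keepdims=True)          # what each rank computes locally
    print(f"rank {rank}: input {shard.shape} -> output {local.shape}")  # (4, 4) -> (4, 1)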
