Distributed Reduction (microsoft#18206)
This PR implements distributed reduction for Llama 2. This version doesn't handle any cases requiring re-sharding, because we haven't seen any such use cases.

Intuitive examples:
- [supported] [2,4,6]-tensor with spec=RRS[0] and device_mesh=[0,1] -> Reduce(axes=[0]) -> [1,4,6]-tensor with spec=RRS[0] and device_mesh=[0,1]
- [supported] [2,4,6]-tensor with spec=RRS[0] and device_mesh=[0,1] -> Reduce(axes=[1]) -> [2,1,6]-tensor with spec=RRS[0] and device_mesh=[0,1]
- [not supported] [2,4,6]-tensor with spec=RRS[0] and device_mesh=[0,1] -> Reduce(axes=[2]) -> [2,4,1]-tensor with spec=RRS[0] and device_mesh=[0,1]

Algorithm: when the reduced axes are not sharded, each device can run the reduction directly, and the output sharding spec is identical to the input sharding spec. We currently throw when the input and output sharding specs differ. A standalone sketch of this rule appears below, after the review guideline.

Review guideline:
- Check 97b8d2f for the new ops' schemas and how the new ops are registered.
- Read the tests in 2450f93 to get familiar with the behavior of these ops.
- Check the implementation details in 753d9af.
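As a minimal standalone sketch of the decision rule above (CanReduceLocally is a hypothetical helper for illustration, not an API in this PR):

// Minimal sketch, not part of this PR: a reduction can run locally on each device
// iff no reduced axis is sharded.
#include <cstdint>
#include <vector>

// partition_axis == -1 denotes a fully replicated tensor (no shard).
bool CanReduceLocally(const std::vector<int64_t>& reduced_axes, int64_t partition_axis) {
  if (partition_axis < 0) return true;
  for (int64_t axis : reduced_axes) {
    if (axis == partition_axis) return false;  // resharding would be needed first
  }
  return true;
}

// For a [2,4,6] tensor with spec RRS[0] (axis 2 sharded across devices 0 and 1):
//   CanReduceLocally({0}, 2) -> true   (supported case above)
//   CanReduceLocally({1}, 2) -> true   (supported case above)
//   CanReduceLocally({2}, 2) -> false  (not supported; the kernel throws)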
Showing 8 changed files with 638 additions and 108 deletions.
onnxruntime/contrib_ops/cuda/collective/distributed_reduce.cc (175 additions, 0 deletions)
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// Distributed computation.
#include "distributed_reduce.h"
#include "sharding.h"
#include "sharding_spec.h"
#include "nccl_kernels.h"
#include "mpi_include.h"

// ORT system.
#include "core/providers/cuda/cudnn_common.h"
#include "core/providers/cuda/reduction/reduction_ops.h"

// std C++.
#include <iostream>

namespace onnxruntime {
namespace contrib {
namespace cuda {

#if defined(ORT_USE_NCCL)
template <typename T>
DistributedReduceBase<T>::DistributedReduceBase(
    const OpKernelInfo& info,
    cudnnReduceTensorOp_t cudnn_reduce_op) : DistributedKernel(info) {
  keepdims_ = info.GetAttrOrDefault<int64_t>("keepdims", 1);
  cudnn_reduce_op_ = cudnn_reduce_op;
};
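// ComputeInternal below consumes two inputs: the data tensor (input 0, on GPU) and the
// axes tensor (input 1, a 1-D int64 tensor; the kernel registrations at the bottom of
// this file place it in CPU memory so its values can be read on the host).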
template <typename T>
Status DistributedReduceBase<T>::ComputeInternal(OpKernelContext* context) const {
  const auto& input_sharding_spec = input_shard_specs_.at(0);
  const auto& axes_sharding_spec = input_shard_specs_.at(1);
  const auto& output_sharding_spec = output_shard_specs_.at(0);

  ORT_ENFORCE(axes_sharding_spec.HasNoShard(),
              "It's not worth sharding the axes tensor. "
              "If sharding axes is needed, please submit a feature request.");

  const Tensor* input_tensor = context->Input<Tensor>(0);
  const Tensor* axes_tensor = context->Input<Tensor>(1);
  ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, "Axes tensor must be a 1-D tensor.");
  auto axes_span = axes_tensor->DataAsSpan<int64_t>();

  // Case 1: empty axes means treating this reduction as an identity.
  if (axes_span.empty()) {
    ORT_ENFORCE(
        input_sharding_spec == output_sharding_spec,
        "Input and output sharding specs should be the same. Otherwise, resharding is needed.");
    auto* output_tensor = context->Output(0, input_tensor->Shape());
    CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor->MutableData<T>(), input_tensor->Data<T>(),
                                         input_tensor->SizeInBytes(),
                                         cudaMemcpyDeviceToDevice, Stream(context)));
    return Status::OK();
  }

  // Case 2: this is a valid reduction. Let's prepare for it.

  bool sharding_on_reduced_axes = false;
  for (auto axis_it = axes_span.begin(); input_sharding_spec.HasShard() && axis_it != axes_span.end(); ++axis_it) {
    if (*axis_it == input_sharding_spec.GetPartitionAxis()) {
      sharding_on_reduced_axes = true;
      break;
    }
  }

  if (sharding_on_reduced_axes) {
    // Case 2-1: sharding on reduced axes.
    ORT_THROW(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::FAIL,
              "Not implemented. Resharding is required to make the reduced axes replicated.");
  } else {
    // Case 2-2: sharding on passing-through axes or no shard.
    ORT_ENFORCE(
        input_sharding_spec == output_sharding_spec,
        "Input and output sharding specs should be the same. Otherwise, resharding is needed.");
    onnxruntime::cuda::PrepareReduceMetadata metadata;
    ORT_RETURN_IF_ERROR(
        onnxruntime::cuda::PrepareForReduce(input_tensor, keepdims_, axes_span, metadata));
    auto output_tensor = context->Output(0, metadata.squeezed_output_dims);

    // Fast reduction is not deterministic, so sometimes we want to turn it off.
    const bool enable_fast_but_non_deterministic_reduction = !context->GetUseDeterministicCompute();
    return onnxruntime::cuda::ReduceComputeCore<T, CUDNN_REDUCE_TENSOR_NO_INDICES>(
        /* GPU allocator */ Info().GetAllocator(OrtMemType::OrtMemTypeDefault),
        *input_tensor, metadata, *output_tensor, cudnn_reduce_op_, axes_span,
        /* calculate_log */ false, /* calculate_sqt */ false, /* log_sum_exp_ */ false,
        enable_fast_but_non_deterministic_reduction, context->GetComputeStream());
  }
  return Status::OK();
}
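// Shape bookkeeping example for ComputeInternal above: reducing a local [2,4,3] shard over
// axes=[1] yields [2,1,3] with keepdims=1 and [2,3] with keepdims=0; PrepareForReduce
// records the resulting dims in metadata.squeezed_output_dims.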
template <typename T>
DistributedReduceSum<T>::DistributedReduceSum(
    const OpKernelInfo& info) : DistributedReduceBase<T>(info, CUDNN_REDUCE_TENSOR_ADD){};

template <typename T>
DistributedReduceMean<T>::DistributedReduceMean(
    const OpKernelInfo& info) : DistributedReduceBase<T>(info, CUDNN_REDUCE_TENSOR_AVG){};

template <typename T>
DistributedReduceMax<T>::DistributedReduceMax(
    const OpKernelInfo& info) : DistributedReduceBase<T>(info, CUDNN_REDUCE_TENSOR_MAX){};
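// Each kernel below is registered for float and MLFloat16. InputMemoryType(OrtMemTypeCPUInput, 1)
// pins input 1 (the axes tensor) to CPU memory, matching how ComputeInternal reads it on the host.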
// ReduceSum
ONNX_OPERATOR_TYPED_KERNEL_EX(
    DistributedReduceSum,
    kMSDomain,
    1,
    float,
    kCudaExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
        .InputMemoryType(OrtMemTypeCPUInput, 1),
    DistributedReduceSum<float>);
ONNX_OPERATOR_TYPED_KERNEL_EX(
    DistributedReduceSum,
    kMSDomain,
    1,
    MLFloat16,
    kCudaExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>())
        .InputMemoryType(OrtMemTypeCPUInput, 1),
    DistributedReduceSum<MLFloat16>);

// ReduceMean
ONNX_OPERATOR_TYPED_KERNEL_EX(
    DistributedReduceMean,
    kMSDomain,
    1,
    float,
    kCudaExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
        .InputMemoryType(OrtMemTypeCPUInput, 1),
    DistributedReduceMean<float>);
ONNX_OPERATOR_TYPED_KERNEL_EX(
    DistributedReduceMean,
    kMSDomain,
    1,
    MLFloat16,
    kCudaExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>())
        .InputMemoryType(OrtMemTypeCPUInput, 1),
    DistributedReduceMean<MLFloat16>);

// ReduceMax
ONNX_OPERATOR_TYPED_KERNEL_EX(
    DistributedReduceMax,
    kMSDomain,
    1,
    float,
    kCudaExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
        .InputMemoryType(OrtMemTypeCPUInput, 1),
    DistributedReduceMax<float>);
ONNX_OPERATOR_TYPED_KERNEL_EX(
    DistributedReduceMax,
    kMSDomain,
    1,
    MLFloat16,
    kCudaExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>())
        .InputMemoryType(OrtMemTypeCPUInput, 1),
    DistributedReduceMax<MLFloat16>);

#endif

}  // namespace cuda
}  // namespace contrib
}  // namespace onnxruntime
onnxruntime/contrib_ops/cuda/collective/distributed_reduce.h (59 additions, 0 deletions)
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "sharding_spec.h"
#include "sharding.h"
#include "core/providers/cuda/cuda_kernel.h"

#include <algorithm>
#include <tuple>
#include <optional>
#include <string>
#include <nccl.h>
#include <sstream>

#pragma once

namespace onnxruntime {
namespace contrib {
namespace cuda {

#if defined(ORT_USE_NCCL)
template <typename T>
class DistributedReduceBase : public DistributedKernel {
 public:
  explicit DistributedReduceBase(const OpKernelInfo& info, cudnnReduceTensorOp_t cudnn_reduce_op);

  Status ComputeInternal(OpKernelContext* context) const override;

 private:
  // ONNX attribute. If true, reduced axes are retained as dimensions with size one.
  // Otherwise, drop reduced axes.
  bool keepdims_;
  cudnnReduceTensorOp_t cudnn_reduce_op_;
};
template <typename T>
class DistributedReduceSum final : public DistributedReduceBase<T> {
 public:
  explicit DistributedReduceSum(const OpKernelInfo& info);
};

template <typename T>
class DistributedReduceMean final : public DistributedReduceBase<T> {
 public:
  explicit DistributedReduceMean(const OpKernelInfo& info);
};

template <typename T>
class DistributedReduceMax final : public DistributedReduceBase<T> {
 public:
  explicit DistributedReduceMax(const OpKernelInfo& info);
};

#endif

}  // namespace cuda
}  // namespace contrib
}  // namespace onnxruntime
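The header shows the extension pattern: each concrete op only selects a cuDNN reduce op for the shared base class. As a hedged sketch (a hypothetical extension, not part of this commit), a min-reduction would be one more subclass:

// Hypothetical DistributedReduceMin following the same pattern; not in this commit.
template <typename T>
class DistributedReduceMin final : public DistributedReduceBase<T> {
 public:
  explicit DistributedReduceMin(const OpKernelInfo& info)
      : DistributedReduceBase<T>(info, CUDNN_REDUCE_TENSOR_MIN) {}
};

It would also need a matching definition and ONNX_OPERATOR_TYPED_KERNEL_EX registrations in the .cc file, plus the schema registration covered in the review guideline.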