Distributed Expand (microsoft#18126)
This PR implements DistributedExpand for Llama 2.

Representative examples of DistributedExpand:

- [shard on non-expanded axis] `input tensor (shape=[8, 1], spec=S[0]R, device_mesh=[0,1]) -> Expand(target_shape=[8, 2]) -> output tensor (shape=[8, 2], spec=S[0]R, device_mesh=[0,1])`
- [sharding the expanded axis is invalid, since an expanded axis must have dim=1 and an axis with dim=1 cannot be sharded, so the shard stays on the non-expanded axis] `input tensor (shape=[1, 8], spec=RS[0], device_mesh=[0,1]) -> Expand(target_shape=[2, 8]) -> output tensor (shape=[2, 8], spec=RS[0], device_mesh=[0,1])`

From these examples, we observe a few important behaviors:

- The output sharding spec is always the same as the input sharding spec.
- Expanding always happens on an axis with dimension 1; otherwise it would violate the broadcasting rule.
- No communication is needed, since all computation can happen locally (the sketch below simulates this).

Consider the first example again. If you put the first half of the tensor (shape: [4, 1]) on device 0 and the second half (shape: [4, 1]) on device 1, then `Expand` each with target shape [4, 2], the two local tensors (shape: [4, 2]) are exactly the shards described by the output sharding spec.

Algorithm:

- Compute the logical (i.e., unsharded) shapes of input and output.
- Compute the sharded output shape from the logical output.
- Call Expand to broadcast the local input to the sharded output shape.

How to review?

- Start with the [changes in onnxruntime_test_distributed.py](microsoft@ea33392). Those tests are good examples of using this op.
- [Read expand.h/expand.cc](microsoft@e4c4998). Those changes expose Expand's functionality to DistributedExpand.
- Read distributed_expand.h/distributed_expand.cc. They follow the algorithm described above.

The commit microsoft@68ac301 first sketches the definition of DistributedExpand. The next commit microsoft@0eb9330 adds the real implementation.
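To make the no-communication claim concrete, here is a minimal standalone sketch (plain C++, independent of ONNX Runtime; the `Expand` helper and all values are illustrative, not the kernel's actual code). It simulates the first example on a two-device mesh and checks that expanding each local shard reproduces the corresponding shard of the global `Expand` result.

```cpp
// Illustrative simulation: for spec S[0]R on mesh [0,1], local Expand on each
// shard equals the corresponding shard of the global Expand. No data moves.
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

// Broadcast a column tensor of shape [rows, 1] to [rows, cols] (row-major).
std::vector<int64_t> Expand(const std::vector<int64_t>& in, int rows, int cols) {
  std::vector<int64_t> out(rows * cols);
  for (int r = 0; r < rows; ++r)
    for (int c = 0; c < cols; ++c)
      out[r * cols + c] = in[r];  // axis 1 has dim=1, so each column repeats it
  return out;
}

int main() {
  // Logical input: shape [8, 1], values 0..7.
  std::vector<int64_t> global = {0, 1, 2, 3, 4, 5, 6, 7};
  // Logical Expand: [8, 1] -> [8, 2].
  const auto global_out = Expand(global, 8, 2);
  // S[0]R shards axis 0: rows 0-3 on device 0, rows 4-7 on device 1.
  for (int device = 0; device < 2; ++device) {
    std::vector<int64_t> local(global.begin() + device * 4,
                               global.begin() + (device + 1) * 4);
    const auto local_out = Expand(local, 4, 2);  // local [4, 1] -> [4, 2]
    for (int i = 0; i < 4 * 2; ++i)
      assert(local_out[i] == global_out[device * 4 * 2 + i]);
  }
  std::puts("local Expand == shard of global Expand");
  return 0;
}
```

Any C++11 compiler will build and run this; nothing in it touches NCCL or MPI, mirroring the claim that DistributedExpand requires no collective communication.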
Showing 9 changed files with 470 additions and 0 deletions.
`onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc` (110 additions, 0 deletions)
```cpp
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// Distributed computation.
#include "distributed_expand.h"
#include "sharding.h"
#include "sharding_spec.h"
#include "nccl_kernels.h"
#include "mpi_include.h"

// ORT system.
#include "core/providers/cuda/tensor/expand.h"

// std C++.
#include <iostream>

namespace onnxruntime {
namespace contrib {
namespace cuda {

#if defined(ORT_USE_NCCL)

template <typename T>
DistributedExpand<T>::DistributedExpand(const OpKernelInfo& info) : DistributedKernel(info) {}

template <typename T>
Status DistributedExpand<T>::ComputeInternal(OpKernelContext* context) const {
  ORT_ENFORCE(context != nullptr);
  // Assumptions.
  // - Shape is not sharded.
  // Algorithm.
  // - Compute logical output shape.
  // - Compute local output shape.
  // - Expand from local input to local output.

  auto input_tensor = context->Input<Tensor>(0);
  auto shape_tensor = context->Input<Tensor>(1);
  const auto& input_sharding_spec = input_shard_specs_.at(0);
  const auto& shape_sharding_spec = input_shard_specs_.at(1);
  const auto& output_sharding_spec = output_shard_specs_.at(0);

  ORT_ENFORCE(shape_sharding_spec.HasNoShard(),
              "It's not worth to shard Shape tensor. "
              "If sharding shape is needed, please submit a feature request.");
  // Compute logical input shape.
  const auto original_input_shape = ComputeOriginShape(input_tensor->Shape(), input_sharding_spec);

  // Compute logical output shape.
  // This `shape_tensor` stores the logical output shape.
  const auto* p_shape = shape_tensor->Data<int64_t>();
  TensorShapeVector original_output_dims{p_shape, p_shape + shape_tensor->Shape().Size()};
  TensorShape original_output_shape(original_output_dims);
  ORT_ENFORCE(
      onnxruntime::cuda::ComputeOutputShape(
          Node().Name(),
          original_input_shape,
          original_output_dims, original_output_shape)
          .IsOK());

  // Compute local output shape.
  const auto local_output_shape = ComputeShardShape(original_output_shape, output_sharding_spec);

  auto output_tensor = context->Output(0, local_output_shape);

  return FuncExpand(
      this,
      context,
      input_tensor,
      shape_tensor,
      output_tensor);
}

ONNX_OPERATOR_TYPED_KERNEL_EX(
    DistributedExpand,
    kMSDomain,
    1,
    int64_t,
    kCudaExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::GetTensorType<int64_t>())
        .InputMemoryType(OrtMemTypeCPUInput, 1),
    DistributedExpand<int64_t>);

ONNX_OPERATOR_TYPED_KERNEL_EX(
    DistributedExpand,
    kMSDomain,
    1,
    float,
    kCudaExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
        .InputMemoryType(OrtMemTypeCPUInput, 1),
    DistributedExpand<float>);

ONNX_OPERATOR_TYPED_KERNEL_EX(
    DistributedExpand,
    kMSDomain,
    1,
    MLFloat16,
    kCudaExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>())
        .InputMemoryType(OrtMemTypeCPUInput, 1),
    DistributedExpand<MLFloat16>);

#endif

}  // namespace cuda
}  // namespace contrib
}  // namespace onnxruntime
```
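The kernel above leans on two helpers from the sharding utilities, `ComputeOriginShape` and `ComputeShardShape`. The stand-ins below are a simplified sketch for a 1-D device mesh only, written for this walkthrough; the real implementations behind `sharding_spec.h` handle general partition specs, so treat the signatures here as assumptions.

```cpp
// Simplified stand-ins (1-D device mesh only) for the sharding helpers used
// above; illustrative, not the actual ORT implementations.
#include <cstdint>
#include <vector>

struct AxisSpec {
  bool sharded;  // true for S[i] (split across devices), false for R (replicated)
};

// Local (per-device) shape -> logical shape: sharded axes grow by mesh size.
std::vector<int64_t> ComputeOriginShape(std::vector<int64_t> shape,
                                        const std::vector<AxisSpec>& spec,
                                        int64_t mesh_size) {
  for (size_t i = 0; i < shape.size(); ++i)
    if (spec[i].sharded) shape[i] *= mesh_size;
  return shape;
}

// Logical shape -> local (per-device) shape: sharded axes shrink by mesh size.
std::vector<int64_t> ComputeShardShape(std::vector<int64_t> shape,
                                       const std::vector<AxisSpec>& spec,
                                       int64_t mesh_size) {
  for (size_t i = 0; i < shape.size(); ++i)
    if (spec[i].sharded) shape[i] /= mesh_size;  // assumes even divisibility
  return shape;
}
```

With spec `S[0]R` on mesh `[0,1]`, a local `[4, 1]` input maps to the logical `[8, 1]`; the logical Expand target `[8, 2]` then maps back to the local output shape `[4, 2]`, which is the shape `FuncExpand` fills on each device.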
`onnxruntime/contrib_ops/cuda/collective/distributed_expand.h` (35 additions, 0 deletions)
```cpp
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "sharding_spec.h"
#include "sharding.h"
#include "core/providers/cuda/cuda_kernel.h"

#include <algorithm>
#include <tuple>
#include <optional>
#include <string>
#include <nccl.h>
#include <sstream>

#pragma once

namespace onnxruntime {
namespace contrib {
namespace cuda {

#if defined(ORT_USE_NCCL)

template <typename T>
class DistributedExpand final : public DistributedKernel {
 public:
  explicit DistributedExpand(const OpKernelInfo& info);

  Status ComputeInternal(OpKernelContext* context) const override;
};

#endif

}  // namespace cuda
}  // namespace contrib
}  // namespace onnxruntime
```
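One more reference point: the `ComputeOutputShape` call in distributed_expand.cc enforces ONNX `Expand` broadcasting, where shapes are right-aligned and each dimension pair must match or contain a 1. Below is a hedged sketch of that rule only (the real helper lives in `core/providers/cuda/tensor/expand.h` and returns a `Status` instead of throwing).

```cpp
// Sketch of ONNX Expand's bidirectional broadcast rule (illustrative only).
#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <vector>

std::vector<int64_t> ExpandOutputShape(std::vector<int64_t> input,
                                       std::vector<int64_t> target) {
  // Right-align the shorter shape by padding with leading 1s.
  while (input.size() < target.size()) input.insert(input.begin(), 1);
  while (target.size() < input.size()) target.insert(target.begin(), 1);
  std::vector<int64_t> out(input.size());
  for (size_t i = 0; i < input.size(); ++i) {
    if (input[i] != target[i] && input[i] != 1 && target[i] != 1)
      throw std::invalid_argument("dims are not broadcastable");
    out[i] = std::max(input[i], target[i]);  // only dim=1 axes get stretched
  }
  return out;
}
```

`ExpandOutputShape({8, 1}, {8, 2})` returns `{8, 2}` and `ExpandOutputShape({1, 8}, {2, 8})` returns `{2, 8}`, matching the two examples in the commit message.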