Skip to content

Commit

Permalink
Skeleton of DistributedExpand
Browse files Browse the repository at this point in the history
  • Loading branch information
wschin committed Oct 27, 2023
1 parent 3d0db47 commit 71e242e
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 0 deletions.
1 change: 1 addition & 0 deletions cmake/onnxruntime_providers_cuda.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_matmul.cc"
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_slice.cc"
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_reshape.cc"
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_expand.cc"
)
endif()
# add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio
Expand Down
1 change: 1 addition & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ if (NOT onnxruntime_USE_NCCL)
list(APPEND contrib_ops_excluded_files "collective/distributed_matmul.cc")
list(APPEND contrib_ops_excluded_files "collective/distributed_slice.cc")
list(APPEND contrib_ops_excluded_files "collective/distributed_reshape.cc")
list(APPEND contrib_ops_excluded_files "collective/distributed_expand.cc")
endif()

set(provider_excluded_files
Expand Down
69 changes: 69 additions & 0 deletions onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// Distributed computation.
#include "distributed_expand.h"
#include "sharding.h"
#include "sharding_spec.h"
#include "nccl_kernels.h"
#include "mpi_include.h"

// ORT system.
#include "core/providers/cuda/tensor/expand.h"

// std C++.
#include <iostream>

namespace onnxruntime {
namespace contrib {
namespace cuda {

#if defined(ORT_USE_NCCL)

// Constructor: delegates all initialization to the DistributedKernel base
// class (which receives the OpKernelInfo carrying the node's attributes).
template <typename T>
DistributedExpand<T>::DistributedExpand(const OpKernelInfo& info) : DistributedKernel(info) {}

// Skeleton implementation: no sharded Expand pattern is handled yet, so every
// invocation fails with NOT_IMPLEMENTED. The null check guards against a
// malformed kernel invocation.
template <typename T>
Status DistributedExpand<T>::ComputeInternal(OpKernelContext* context) const {
  ORT_ENFORCE(context != nullptr);
  constexpr const char* kNotImplementedMsg = "Encounter unsupported expand pattern.";
  return Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, kNotImplementedMsg);
}

// Register DistributedExpand with the CUDA execution provider under the
// Microsoft domain (opset 1), once per supported element type
// (int64_t, float, MLFloat16). In every registration, input 1 — the target
// shape tensor — is pinned to CPU memory via
// InputMemoryType(OrtMemTypeCPUInput, 1).
ONNX_OPERATOR_TYPED_KERNEL_EX(
    DistributedExpand,
    kMSDomain,
    1,
    int64_t,
    kCudaExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::GetTensorType<int64_t>())
        .InputMemoryType(OrtMemTypeCPUInput, 1),
    DistributedExpand<int64_t>);

ONNX_OPERATOR_TYPED_KERNEL_EX(
    DistributedExpand,
    kMSDomain,
    1,
    float,
    kCudaExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
        .InputMemoryType(OrtMemTypeCPUInput, 1),
    DistributedExpand<float>);

ONNX_OPERATOR_TYPED_KERNEL_EX(
    DistributedExpand,
    kMSDomain,
    1,
    MLFloat16,
    kCudaExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>())
        .InputMemoryType(OrtMemTypeCPUInput, 1),
    DistributedExpand<MLFloat16>);

#endif

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
35 changes: 35 additions & 0 deletions onnxruntime/contrib_ops/cuda/collective/distributed_expand.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// Fix: #pragma once must be the first meaningful line of the header (it
// previously appeared after the includes), so the guard covers the whole file.
#pragma once

#include "sharding_spec.h"
#include "sharding.h"
#include "core/providers/cuda/cuda_kernel.h"

#include <algorithm>
#include <optional>
#include <sstream>
#include <string>
#include <tuple>

#include <nccl.h>

namespace onnxruntime {
namespace contrib {
namespace cuda {

#if defined(ORT_USE_NCCL)

// Skeleton of the distributed (sharded) Expand kernel for the CUDA execution
// provider. Only the declaration lives here; the implementation currently
// rejects all inputs with NOT_IMPLEMENTED.
template <typename T>
class DistributedExpand final : public DistributedKernel {
 public:
  // Forwards the kernel info to the DistributedKernel base class.
  explicit DistributedExpand(const OpKernelInfo& info);

  // Computes the sharded expand; returns NOT_IMPLEMENTED in this skeleton.
  Status ComputeInternal(OpKernelContext* context) const override;
};

#endif

}  // namespace cuda
}  // namespace contrib
}  // namespace onnxruntime
8 changes: 8 additions & 0 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedReshape);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReshape);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReshape);

class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedExpand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedExpand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedExpand);
#endif

template <>
Expand Down Expand Up @@ -342,6 +346,10 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedReshape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReshape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReshape)>,

BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedExpand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedExpand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedExpand)>,
#endif

};
Expand Down
37 changes: 37 additions & 0 deletions onnxruntime/core/graph/contrib_ops/collective_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,43 @@ void RegisterCollectiveOps() {
OpSchema::NonDifferentiable)
.Output(0, "reshaped", "Reshaped data.", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
.TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensor types.");

// Schema for DistributedExpand: Expand annotated with device-mesh and
// sharding-spec attributes for each input/output.
// Fixes two doc-string defects: "input_device_mesh_elments" typo and
// "a unsharded" grammar.
ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedExpand)
    .SetDomain(kMSDomain)
    .SinceVersion(1)
    .Attr("input_device_mesh_elements",
          "device_mesh_elements[i] defines the device mesh's value for the i-th input. "
          "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd "
          " inputs are stored on the 0-th and the 1st devices, respectively.",
          AttributeProto::STRINGS)
    .Attr("input_device_mesh_shapes",
          "device_mesh_shape[i] defines the device mesh's shape for the i-th input.",
          AttributeProto::STRINGS)
    .Attr("input_shard_specs",
          "The sharding spec of inputs. "
          "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is an unsharded 3-D tensor.",
          AttributeProto::STRINGS)
    .Attr("output_device_mesh_elements",
          "Similar to input_device_mesh_elements but for outputs.",
          AttributeProto::STRINGS)
    .Attr("output_device_mesh_shapes",
          "Similar to input_device_mesh_shapes but for outputs.",
          AttributeProto::STRINGS)
    .Attr("output_shard_specs",
          "Similar to input_shard_specs but for outputs.",
          AttributeProto::STRINGS)
    .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
    .Input(
        1,
        "shape",
        "A 1-D tensor indicates the shape you want to expand to, following the broadcast rule",
        "tensor(int64)",
        OpSchema::Single,
        true,
        1,
        OpSchema::NonDifferentiable)
    .Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
    .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors.");
}

} // namespace contrib
Expand Down

0 comments on commit 71e242e

Please sign in to comment.