Distributed Reduction #18206

Merged · 6 commits · Nov 1, 2023
Changes from 5 commits
1 change: 1 addition & 0 deletions cmake/onnxruntime_providers_cuda.cmake
@@ -40,6 +40,7 @@
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_slice.cc"
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_reshape.cc"
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_expand.cc"
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_reduce.cc"
)
endif()
# add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio
1 change: 1 addition & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -109,6 +109,7 @@ if (NOT onnxruntime_USE_NCCL)
list(APPEND contrib_ops_excluded_files "collective/distributed_slice.cc")
list(APPEND contrib_ops_excluded_files "collective/distributed_reshape.cc")
list(APPEND contrib_ops_excluded_files "collective/distributed_expand.cc")
list(APPEND contrib_ops_excluded_files "collective/distributed_reduce.cc")
endif()

set(provider_excluded_files
175 changes: 175 additions & 0 deletions onnxruntime/contrib_ops/cuda/collective/distributed_reduce.cc
@@ -0,0 +1,175 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// Distributed computation.
#include "distributed_reduce.h"
#include "sharding.h"
#include "sharding_spec.h"
#include "nccl_kernels.h"
#include "mpi_include.h"

// ORT system.
#include "core/providers/cuda/cudnn_common.h"
#include "core/providers/cuda/reduction/reduction_ops.h"

// std C++.
#include <iostream>

namespace onnxruntime {
namespace contrib {
namespace cuda {

#if defined(ORT_USE_NCCL)

template <typename T>
DistributedReduceBase<T>::DistributedReduceBase(
const OpKernelInfo& info,
cudnnReduceTensorOp_t cudnn_reduce_op) : DistributedKernel(info) {
keepdims_ = info.GetAttrOrDefault<int64_t>("keepdims", 1);
cudnn_reduce_op_ = cudnn_reduce_op;
}

template <typename T>
Status DistributedReduceBase<T>::ComputeInternal(OpKernelContext* context) const {
const auto& input_sharding_spec = input_shard_specs_.at(0);
const auto& axes_sharding_spec = input_shard_specs_.at(1);
const auto& output_sharding_spec = output_shard_specs_.at(0);

ORT_ENFORCE(axes_sharding_spec.HasNoShard(),
"Sharding the axes tensor is not worthwhile. "
"If sharding axes is needed, please submit a feature request.");

const Tensor* input_tensor = context->Input<Tensor>(0);
const Tensor* axes_tensor = context->Input<Tensor>(1);
ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, "Axes tensor must be a 1-D tensor.");
auto axes_span = axes_tensor->DataAsSpan<int64_t>();

// Case 1: empty axes means treating this reduction as an identity.
if (axes_span.empty()) {
ORT_ENFORCE(
input_sharding_spec == output_sharding_spec,
"Input and output sharding specs should be the same. Otherwise, resharding is needed.");
auto* output_tensor = context->Output(0, input_tensor->Shape());
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor->MutableData<T>(), input_tensor->Data<T>(),
input_tensor->SizeInBytes(), cudaMemcpyDeviceToDevice, Stream(context)));
return Status::OK();
}

// Case 2: this is a valid reduction. Let's prepare for it.

bool sharding_on_reduced_axes = false;
for (auto axis_it = axes_span.begin(); input_sharding_spec.HasShard() && axis_it != axes_span.end(); ++axis_it) {
if (*axis_it == input_sharding_spec.GetPartitionAxis()) {
sharding_on_reduced_axes = true;
break;
}
}

if (sharding_on_reduced_axes) {
// Case 2-1: sharding on reduced axes.
ORT_THROW(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::FAIL,
"Not implemented. Resharding is required to replicate the reduced axes.");
} else {
// Case 2-2: sharding on passing-through axes or no shard.
ORT_ENFORCE(
input_sharding_spec == output_sharding_spec,
"Input and output sharding specs should be the same. Otherwise, resharding is needed.");
onnxruntime::cuda::PrepareReduceMetadata metadata;
ORT_RETURN_IF_ERROR(
onnxruntime::cuda::PrepareForReduce(input_tensor, keepdims_, axes_span, metadata));
auto output_tensor = context->Output(0, metadata.squeezed_output_dims);

// Fast reduction is not deterministic, so sometimes we want to turn it off.
const bool enable_fast_but_non_deterministic_reduction = !context->GetUseDeterministicCompute();
return onnxruntime::cuda::ReduceComputeCore<T, CUDNN_REDUCE_TENSOR_NO_INDICES>(
/* GPU allocator */ Info().GetAllocator(OrtMemType::OrtMemTypeDefault),
*input_tensor, metadata, *output_tensor, cudnn_reduce_op_, axes_span,
/* calculate_log */ false, /* calculate_sqt */ false, /* log_sum_exp_ */ false,
enable_fast_but_non_deterministic_reduction, context->GetComputeStream());
}
return Status::OK();
}

template <typename T>
DistributedReduceSum<T>::DistributedReduceSum(
const OpKernelInfo& info) : DistributedReduceBase<T>(info, CUDNN_REDUCE_TENSOR_ADD) {}

template <typename T>
DistributedReduceMean<T>::DistributedReduceMean(
const OpKernelInfo& info) : DistributedReduceBase<T>(info, CUDNN_REDUCE_TENSOR_AVG) {}

template <typename T>
DistributedReduceMax<T>::DistributedReduceMax(
const OpKernelInfo& info) : DistributedReduceBase<T>(info, CUDNN_REDUCE_TENSOR_MAX) {}

// ReduceSum
ONNX_OPERATOR_TYPED_KERNEL_EX(
DistributedReduceSum,
kMSDomain,
1,
float,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
.InputMemoryType(OrtMemTypeCPUInput, 1),
DistributedReduceSum<float>);
ONNX_OPERATOR_TYPED_KERNEL_EX(
DistributedReduceSum,
kMSDomain,
1,
MLFloat16,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>())
.InputMemoryType(OrtMemTypeCPUInput, 1),
DistributedReduceSum<MLFloat16>);

// ReduceMean
ONNX_OPERATOR_TYPED_KERNEL_EX(
DistributedReduceMean,
kMSDomain,
1,
float,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
.InputMemoryType(OrtMemTypeCPUInput, 1),
DistributedReduceMean<float>);
ONNX_OPERATOR_TYPED_KERNEL_EX(
DistributedReduceMean,
kMSDomain,
1,
MLFloat16,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>())
.InputMemoryType(OrtMemTypeCPUInput, 1),
DistributedReduceMean<MLFloat16>);

// ReduceMax
ONNX_OPERATOR_TYPED_KERNEL_EX(
DistributedReduceMax,
kMSDomain,
1,
float,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
.InputMemoryType(OrtMemTypeCPUInput, 1),
DistributedReduceMax<float>);
ONNX_OPERATOR_TYPED_KERNEL_EX(
DistributedReduceMax,
kMSDomain,
1,
MLFloat16,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>())
.InputMemoryType(OrtMemTypeCPUInput, 1),
DistributedReduceMax<MLFloat16>);

#endif

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
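
Editorial note: ComputeInternal above distinguishes three outcomes — an identity copy (empty axes), an unimplemented cross-shard reduction, and a purely local reduction. The standalone sketch below restates that decision logic; it is illustrative only, not part of the PR, and its single-partition-axis model of the sharding spec (with -1 meaning no shard) is a simplifying assumption.

#include <cstdint>
#include <iostream>
#include <vector>

// Possible outcomes of DistributedReduceBase<T>::ComputeInternal,
// mirroring the Case 1 / Case 2-1 / Case 2-2 branches above.
enum class ReduceAction { Identity, NeedsResharding, LocalReduce };

// partition_axis is the axis the input is sharded on, or -1 for no shard
// (a simplified stand-in for ShardingSpec::HasShard()/GetPartitionAxis()).
ReduceAction ClassifyReduction(const std::vector<int64_t>& axes, int64_t partition_axis) {
  // Case 1: empty axes -> the reduction degenerates to a device-to-device copy.
  if (axes.empty()) return ReduceAction::Identity;
  // Case 2-1: reducing along the sharded axis needs cross-rank communication,
  // which this PR leaves unimplemented (the kernel throws).
  for (int64_t axis : axes) {
    if (axis == partition_axis) return ReduceAction::NeedsResharding;
  }
  // Case 2-2: every reduced axis is fully present on each rank, so each rank
  // runs an ordinary cuDNN reduction on its local shard.
  return ReduceAction::LocalReduce;
}

int main() {
  // Input sharded on axis 0 (e.g. spec "S[0]R"), reducing axis 1: local reduce works.
  std::cout << (ClassifyReduction({1}, 0) == ReduceAction::LocalReduce) << "\n";      // prints 1
  // Reducing the sharded axis 0 itself would require resharding first.
  std::cout << (ClassifyReduction({0}, 0) == ReduceAction::NeedsResharding) << "\n";  // prints 1
  return 0;
}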
59 changes: 59 additions & 0 deletions onnxruntime/contrib_ops/cuda/collective/distributed_reduce.h
@@ -0,0 +1,59 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "sharding_spec.h"
#include "sharding.h"
#include "core/providers/cuda/cuda_kernel.h"

#include <algorithm>
#include <tuple>
#include <optional>
#include <string>
#include <nccl.h>
#include <sstream>

#pragma once

namespace onnxruntime {
namespace contrib {
namespace cuda {

#if defined(ORT_USE_NCCL)

template <typename T>
class DistributedReduceBase : public DistributedKernel {
public:
explicit DistributedReduceBase(const OpKernelInfo& info, cudnnReduceTensorOp_t cudnn_reduce_op);

Status ComputeInternal(OpKernelContext* context) const override;

private:
// ONNX attribute. If true, reduced axes are retained as dimensions with size one.
// Otherwise, drop reduced axes.
bool keepdims_;
cudnnReduceTensorOp_t cudnn_reduce_op_;
};

template <typename T>
class DistributedReduceSum final : public DistributedReduceBase<T> {
public:
explicit DistributedReduceSum(const OpKernelInfo& info);
};

template <typename T>
class DistributedReduceMean final : public DistributedReduceBase<T> {
public:
explicit DistributedReduceMean(const OpKernelInfo& info);
};

template <typename T>
class DistributedReduceMax final : public DistributedReduceBase<T> {
public:
explicit DistributedReduceMax(const OpKernelInfo& info);
};

#endif

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
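
Editorial note: because the cuDNN reduce op is the only thing that varies across the subclasses above, adding another reduction is mechanical. As a hypothetical illustration (not part of this PR), a min-reduction could follow the same pattern, assuming cuDNN's CUDNN_REDUCE_TENSOR_MIN:

// Hypothetical DistributedReduceMin, mirroring the classes above.
// Declaration (would live in distributed_reduce.h):
template <typename T>
class DistributedReduceMin final : public DistributedReduceBase<T> {
 public:
  explicit DistributedReduceMin(const OpKernelInfo& info);
};

// Definition (would live in distributed_reduce.cc), selecting the min reduce op:
template <typename T>
DistributedReduceMin<T>::DistributedReduceMin(
    const OpKernelInfo& info) : DistributedReduceBase<T>(info, CUDNN_REDUCE_TENSOR_MIN) {}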
18 changes: 18 additions & 0 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -174,6 +174,15 @@
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedExpand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedExpand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedExpand);

class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceSum);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceSum);

class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceMax);

class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceMean);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceMean);
#endif

template <>
@@ -352,6 +361,15 @@
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedExpand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedExpand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedExpand)>,

BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceSum)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceSum)>,

BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceMax)>,

BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceMean)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReduceMean)>,
#endif

};
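
Editorial note: to complete the hypothetical DistributedReduceMin sketched earlier (again, not part of this PR), this file would need entries mirroring the Sum/Max/Mean ones — a kernel class declaration in the first hunk and a matching table entry in the second — plus an ONNX_OPERATOR_TYPED_KERNEL_EX block back in distributed_reduce.cc:

// Forward declaration alongside the other DistributedReduce* kernels:
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceMin);
// Entry in the kernel-creation table:
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReduceMin)>,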