
Distributed Expand #18126

Merged 5 commits on Oct 28, 2023
1 change: 1 addition & 0 deletions cmake/onnxruntime_providers_cuda.cmake
@@ -39,6 +39,7 @@
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_matmul.cc"
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_slice.cc"
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_reshape.cc"
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_expand.cc"
)
endif()
# add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio
1 change: 1 addition & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -108,6 +108,7 @@ if (NOT onnxruntime_USE_NCCL)
list(APPEND contrib_ops_excluded_files "collective/distributed_matmul.cc")
list(APPEND contrib_ops_excluded_files "collective/distributed_slice.cc")
list(APPEND contrib_ops_excluded_files "collective/distributed_reshape.cc")
list(APPEND contrib_ops_excluded_files "collective/distributed_expand.cc")
endif()

set(provider_excluded_files
110 changes: 110 additions & 0 deletions onnxruntime/contrib_ops/cuda/collective/distributed_expand.cc
@@ -0,0 +1,110 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// Distributed computation.
#include "distributed_expand.h"
#include "sharding.h"
#include "sharding_spec.h"
#include "nccl_kernels.h"
#include "mpi_include.h"

// ORT system.
#include "core/providers/cuda/tensor/expand.h"

// std C++.
#include <iostream>

namespace onnxruntime {
namespace contrib {
namespace cuda {

#if defined(ORT_USE_NCCL)

template <typename T>
DistributedExpand<T>::DistributedExpand(const OpKernelInfo& info) : DistributedKernel(info) {}

template <typename T>
Status DistributedExpand<T>::ComputeInternal(OpKernelContext* context) const {
  ORT_ENFORCE(context != nullptr);
  // Assumptions:
  // - Shape is not sharded.
  // Algorithm:
  // - Compute logical output shape.
  // - Compute local output shape.
  // - Expand from local input to local output.

  auto input_tensor = context->Input<Tensor>(0);
  auto shape_tensor = context->Input<Tensor>(1);
  const auto& input_sharding_spec = input_shard_specs_.at(0);
  const auto& shape_sharding_spec = input_shard_specs_.at(1);
  const auto& output_sharding_spec = output_shard_specs_.at(0);

  ORT_ENFORCE(shape_sharding_spec.HasNoShard(),
              "It's not worth sharding the Shape tensor. "
              "If sharding the shape is needed, please submit a feature request.");
  // Compute logical input shape.
  const auto original_input_shape = ComputeOriginShape(input_tensor->Shape(), input_sharding_spec);

  // Compute logical output shape.
  // This `shape_tensor` stores the logical output shape.
  const auto* p_shape = shape_tensor->Data<int64_t>();
  TensorShapeVector original_output_dims{p_shape, p_shape + shape_tensor->Shape().Size()};
  TensorShape original_output_shape(original_output_dims);
  ORT_ENFORCE(
      onnxruntime::cuda::ComputeOutputShape(
          Node().Name(),
          original_input_shape,
          original_output_dims, original_output_shape)
          .IsOK());

  // Compute local output shape.
  const auto local_output_shape = ComputeShardShape(original_output_shape, output_sharding_spec);

  auto output_tensor = context->Output(0, local_output_shape);

  return FuncExpand(
      this,
      context,
      input_tensor,
      shape_tensor,
      output_tensor);
}

ONNX_OPERATOR_TYPED_KERNEL_EX(
    DistributedExpand,
    kMSDomain,
    1,
    int64_t,
    kCudaExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::GetTensorType<int64_t>())
        .InputMemoryType(OrtMemTypeCPUInput, 1),
    DistributedExpand<int64_t>);

ONNX_OPERATOR_TYPED_KERNEL_EX(
    DistributedExpand,
    kMSDomain,
    1,
    float,
    kCudaExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
        .InputMemoryType(OrtMemTypeCPUInput, 1),
    DistributedExpand<float>);

ONNX_OPERATOR_TYPED_KERNEL_EX(
    DistributedExpand,
    kMSDomain,
    1,
    MLFloat16,
    kCudaExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>())
        .InputMemoryType(OrtMemTypeCPUInput, 1),
    DistributedExpand<MLFloat16>);

#endif

}  // namespace cuda
}  // namespace contrib
}  // namespace onnxruntime
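Note: the registrations pin input 1 (the shape tensor) to CPU memory via InputMemoryType(OrtMemTypeCPUInput, 1), which matches ComputeInternal reading shape_tensor->Data<int64_t>() on the host. For intuition about the shape arithmetic, here is a standalone sketch of the logical-vs-local computation. It is a simplification under stated assumptions — a single shard axis (or none) and an even split across devices — not the real ShardingSpec/ComputeOriginShape/ComputeShardShape API:

// Standalone sketch of the shape math in DistributedExpand::ComputeInternal.
// Assumption: one shard axis (-1 = replicated) and an even split across
// `num_devices`; the real ShardingSpec supports general device meshes.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int64_t>;

// Recover the logical (unsharded) shape from a local shard.
Shape ComputeOriginShapeSketch(Shape local, int shard_axis, int64_t num_devices) {
  if (shard_axis >= 0) local[shard_axis] *= num_devices;
  return local;
}

// Slice the logical shape down to one device's shard.
Shape ComputeShardShapeSketch(Shape logical, int shard_axis, int64_t num_devices) {
  if (shard_axis >= 0) {
    assert(logical[shard_axis] % num_devices == 0);
    logical[shard_axis] /= num_devices;
  }
  return logical;
}

int main() {
  // Local input [4, 1], sharded on axis 0 across 2 devices -> logical [8, 1].
  const Shape logical_in = ComputeOriginShapeSketch({4, 1}, /*shard_axis=*/0, 2);
  // Logical Expand target [8, 6]; output sharded on axis 0 -> local [4, 6].
  const Shape local_out = ComputeShardShapeSketch({8, 6}, /*shard_axis=*/0, 2);
  std::cout << logical_in[0] << "x" << logical_in[1] << " -> "
            << local_out[0] << "x" << local_out[1] << "\n";  // prints "8x1 -> 4x6"
  return 0;
}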
35 changes: 35 additions & 0 deletions onnxruntime/contrib_ops/cuda/collective/distributed_expand.h
@@ -0,0 +1,35 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "sharding_spec.h"

Check warning on line 4 in onnxruntime/contrib_ops/cuda/collective/distributed_expand.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/cuda/collective/distributed_expand.h#L4

Include the directory when naming header files [build/include_subdir] [4]
Raw output
onnxruntime/contrib_ops/cuda/collective/distributed_expand.h:4:  Include the directory when naming header files  [build/include_subdir] [4]
#include "sharding.h"

Check warning on line 5 in onnxruntime/contrib_ops/cuda/collective/distributed_expand.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/cuda/collective/distributed_expand.h#L5

Include the directory when naming header files [build/include_subdir] [4]
Raw output
onnxruntime/contrib_ops/cuda/collective/distributed_expand.h:5:  Include the directory when naming header files  [build/include_subdir] [4]
#include "core/providers/cuda/cuda_kernel.h"

#include <algorithm>

Check warning on line 8 in onnxruntime/contrib_ops/cuda/collective/distributed_expand.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/cuda/collective/distributed_expand.h#L8

Found C++ system header after other header. Should be: distributed_expand.h, c system, c++ system, other. [build/include_order] [4]
Raw output
onnxruntime/contrib_ops/cuda/collective/distributed_expand.h:8:  Found C++ system header after other header. Should be: distributed_expand.h, c system, c++ system, other.  [build/include_order] [4]
#include <tuple>

Check warning on line 9 in onnxruntime/contrib_ops/cuda/collective/distributed_expand.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/cuda/collective/distributed_expand.h#L9

Found C++ system header after other header. Should be: distributed_expand.h, c system, c++ system, other. [build/include_order] [4]
Raw output
onnxruntime/contrib_ops/cuda/collective/distributed_expand.h:9:  Found C++ system header after other header. Should be: distributed_expand.h, c system, c++ system, other.  [build/include_order] [4]
#include <optional>

Check warning on line 10 in onnxruntime/contrib_ops/cuda/collective/distributed_expand.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/cuda/collective/distributed_expand.h#L10

Found C++ system header after other header. Should be: distributed_expand.h, c system, c++ system, other. [build/include_order] [4]
Raw output
onnxruntime/contrib_ops/cuda/collective/distributed_expand.h:10:  Found C++ system header after other header. Should be: distributed_expand.h, c system, c++ system, other.  [build/include_order] [4]
#include <string>

Check warning on line 11 in onnxruntime/contrib_ops/cuda/collective/distributed_expand.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/cuda/collective/distributed_expand.h#L11

Found C++ system header after other header. Should be: distributed_expand.h, c system, c++ system, other. [build/include_order] [4]
Raw output
onnxruntime/contrib_ops/cuda/collective/distributed_expand.h:11:  Found C++ system header after other header. Should be: distributed_expand.h, c system, c++ system, other.  [build/include_order] [4]
#include <nccl.h>

Check warning on line 12 in onnxruntime/contrib_ops/cuda/collective/distributed_expand.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/cuda/collective/distributed_expand.h#L12

Found C system header after other header. Should be: distributed_expand.h, c system, c++ system, other. [build/include_order] [4]
Raw output
onnxruntime/contrib_ops/cuda/collective/distributed_expand.h:12:  Found C system header after other header. Should be: distributed_expand.h, c system, c++ system, other.  [build/include_order] [4]
#include <sstream>

Check warning on line 13 in onnxruntime/contrib_ops/cuda/collective/distributed_expand.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/cuda/collective/distributed_expand.h#L13

Found C++ system header after other header. Should be: distributed_expand.h, c system, c++ system, other. [build/include_order] [4]
Raw output
onnxruntime/contrib_ops/cuda/collective/distributed_expand.h:13:  Found C++ system header after other header. Should be: distributed_expand.h, c system, c++ system, other.  [build/include_order] [4]

#pragma once

namespace onnxruntime {
namespace contrib {
namespace cuda {

#if defined(ORT_USE_NCCL)

template <typename T>
class DistributedExpand final : public DistributedKernel {
 public:
  explicit DistributedExpand(const OpKernelInfo& info);

  Status ComputeInternal(OpKernelContext* context) const override;
};

#endif

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
8 changes: 8 additions & 0 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -170,6 +170,10 @@
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedReshape);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReshape);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReshape);

class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedExpand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedExpand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedExpand);
#endif

template <>
@@ -344,6 +348,10 @@
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedReshape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedReshape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedReshape)>,

BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int64_t, DistributedExpand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DistributedExpand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedExpand)>,
#endif

};
37 changes: 37 additions & 0 deletions onnxruntime/core/graph/contrib_ops/collective_defs.cc
@@ -236,6 +236,43 @@ void RegisterCollectiveOps() {
          OpSchema::NonDifferentiable)
      .Output(0, "reshaped", "Reshaped data.", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
      .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensor types.");

  ONNX_CONTRIB_OPERATOR_SCHEMA(DistributedExpand)
      .SetDomain(kMSDomain)
      .SinceVersion(1)
      .Attr("input_device_mesh_elements",
            "device_mesh_elements[i] defines the device mesh's value for the i-th input. "
            "E.g., device_mesh_elements=[\"[0, 1]\", \"[0, 1]\"] means the 1st and the 2nd "
            "inputs are stored on the 0-th and the 1st devices, respectively.",
            AttributeProto::STRINGS)
      .Attr("input_device_mesh_shapes",
            "device_mesh_shape[i] defines the device mesh's shape for the i-th input.",
            AttributeProto::STRINGS)
      .Attr("input_shard_specs",
            "The sharding spec of inputs. "
            "E.g., if input_shard_specs[i] is \"RRR\", the i-th input is an unsharded 3-D tensor.",
            AttributeProto::STRINGS)
      .Attr("output_device_mesh_elements",
            "Similar to input_device_mesh_elements but for outputs.",
            AttributeProto::STRINGS)
      .Attr("output_device_mesh_shapes",
            "Similar to input_device_mesh_shapes but for outputs.",
            AttributeProto::STRINGS)
      .Attr("output_shard_specs",
            "Similar to input_shard_specs but for outputs.",
            AttributeProto::STRINGS)
      .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
      .Input(
          1,
          "shape",
          "A 1-D tensor that indicates the shape to expand to, following the broadcast rule.",
          "tensor(int64)",
          OpSchema::Single,
          true,
          1,
          OpSchema::NonDifferentiable)
      .Output(0, "output", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
      .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Constrain input and output types to all tensors.");
}

} // namespace contrib
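Note: the shard-spec strings above use 'R' for a replicated axis and 'S[i]' for an axis sharded over device-mesh axis i (e.g., "S[0]RR"); the real parser lives in sharding_spec.{h,cc}. A minimal, hypothetical parser for just that surface syntax — illustrative only, not the actual implementation:

// Hypothetical sketch: map each tensor axis to a device-mesh axis (-1 = replicated).
// Assumes well-formed input such as "RRR" or "S[0]RR".
#include <iostream>
#include <string>
#include <vector>

std::vector<int> ParseShardSpecSketch(const std::string& spec) {
  std::vector<int> mesh_axis_per_dim;
  for (size_t i = 0; i < spec.size();) {
    if (spec[i] == 'R') {
      mesh_axis_per_dim.push_back(-1);  // replicated axis
      ++i;
    } else {
      // "S[<mesh axis>]"
      const size_t close = spec.find(']', i);
      mesh_axis_per_dim.push_back(std::stoi(spec.substr(i + 2, close - i - 2)));
      i = close + 1;
    }
  }
  return mesh_axis_per_dim;
}

int main() {
  for (int axis : ParseShardSpecSketch("S[0]RR")) std::cout << axis << ' ';  // prints "0 -1 -1"
  std::cout << '\n';
  return 0;
}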
80 changes: 80 additions & 0 deletions onnxruntime/core/providers/cuda/tensor/expand.cc
@@ -142,6 +142,86 @@
      input_strides);
}

Status FuncExpand(
    const CudaKernel* cuda_kernel,
    OpKernelContext* ctx,
    const Tensor* input_data_tensor,
    const Tensor* /*input_shape_tensor*/,
    Tensor* output_tensor) {
  TensorShape output_shape = output_tensor->Shape();

#ifdef ENABLE_STRIDED_TENSORS
  // Strided output.
  if (input_data_tensor->DataRaw() == output_tensor->DataRaw()) {
    gsl::span<const int64_t> input_strides = input_data_tensor->Strides();
    TensorShapeVector output_strides =
        ComputeOutputStrides(input_data_tensor->Shape(), input_strides, output_shape);
    output_tensor->SetShapeAndStrides(output_shape, output_strides);
    return Status::OK();
  }
#endif

  auto output_dims = output_shape.AsShapeVector();
  auto input_dims = input_data_tensor->Shape().AsShapeVector();

  CalcEffectiveDims(input_dims, output_dims);
  int rank = gsl::narrow_cast<int>(output_dims.size());

  TensorPitches original_input_strides(input_dims);
  TensorPitches original_output_strides(output_dims);

  // A broadcast (size-1) dimension reads the same input element for every
  // output index, which is expressed as an input stride of zero.
  TArray<int64_t> input_strides(rank);
  for (auto i = 0; i < rank; i++) {
    input_strides[i] = input_dims[i] == 1 ? 0 : original_input_strides[i];
  }

  TArray<fast_divmod> output_strides(rank);
  for (auto i = 0; i < rank; i++) {
    output_strides[i] = fast_divmod(static_cast<int>(original_output_strides[i]));
  }

  return ExpandImpl(
      cuda_kernel->Stream(ctx),
      input_data_tensor->DataType()->Size(),
      gsl::narrow_cast<int>(output_shape.Size()),
      gsl::narrow_cast<int>(input_data_tensor->Shape().Size()),
      input_data_tensor->DataRaw(),
      output_tensor->MutableDataRaw(),
      output_strides,
      input_strides);
}

std::unique_ptr<Tensor> FuncExpand(
    const CudaKernel* cuda_kernel,
    OpKernelContext* ctx,
    const Tensor* input_data_tensor,
    const Tensor* input_shape_tensor) {
  // New shape to be expanded to.
  const auto* p_shape = input_shape_tensor->Data<int64_t>();
  TensorShapeVector output_dims{p_shape, p_shape + input_shape_tensor->Shape().Size()};
  TensorShape output_shape(output_dims);

  ORT_ENFORCE(
      ComputeOutputShape(
          cuda_kernel->Node().Name(),
          input_data_tensor->Shape(),
          output_dims, output_shape)
          .IsOK());

  // Pre-allocate output.
  AllocatorPtr alloc;
  ORT_ENFORCE(ctx->GetTempSpaceAllocator(&alloc).IsOK());
  auto output_tensor = Tensor::Create(input_data_tensor->DataType(), output_shape, alloc);

  // Only assign output values when the output tensor is non-empty,
  // because an empty tensor doesn't own any data.
  if (output_shape.Size() > 0) {
    ORT_ENFORCE(FuncExpand(cuda_kernel, ctx, input_data_tensor, input_shape_tensor, output_tensor.get()).IsOK());
  }

  return output_tensor;
}
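Note: the non-strided path in FuncExpand implements Expand purely by index arithmetic: after CalcEffectiveDims, every input dimension of size 1 gets input stride 0, so all output positions along a broadcast dimension read the same input element. A minimal CPU model of that mapping (hard-coded 2-D shapes; ExpandImpl performs the equivalent per-element mapping on the GPU):

// CPU model of the zero-stride broadcast used by FuncExpand/ExpandImpl.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Expand input of shape [2, 1] to output of shape [2, 3].
  const std::vector<float> input = {1.f, 2.f};
  const int64_t in_strides[2] = {1, 0};  // size-1 dim 1 is broadcast -> stride 0
  const int64_t out_dims[2] = {2, 3};
  std::vector<float> output(out_dims[0] * out_dims[1]);
  for (int64_t i = 0; i < out_dims[0]; ++i) {
    for (int64_t j = 0; j < out_dims[1]; ++j) {
      output[i * out_dims[1] + j] = input[i * in_strides[0] + j * in_strides[1]];
    }
  }
  for (float v : output) std::cout << v << ' ';  // prints "1 1 1 2 2 2"
  std::cout << '\n';
  return 0;
}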

#ifdef ENABLE_STRIDED_TENSORS
#define CREATE_EXPAND_KERNEL_DEF (*KernelDefBuilder::Create()).MayStridedOutput(0, 0)
#else
13 changes: 13 additions & 0 deletions onnxruntime/core/providers/cuda/tensor/expand.h
@@ -20,5 +20,18 @@
    const TensorShape& rhs_shape,
    TensorShape& out_shape);

Status FuncExpand(
    const CudaKernel* cuda_kernel,
    OpKernelContext* ctx,
    const Tensor* input_data_tensor,
    const Tensor* /*input_shape_tensor*/,
    Tensor* output_tensor);

std::unique_ptr<Tensor> FuncExpand(
    const CudaKernel* cuda_kernel,
    OpKernelContext* ctx,
    const Tensor* input_data_tensor,
    const Tensor* input_shape_tensor);

} // namespace cuda
} // namespace onnxruntime